Add load balancer

This commit is contained in:
Dany Mahmalat 2022-02-07 14:39:01 -05:00
parent eaf95f6bea
commit 774b9bf02f
7 changed files with 89 additions and 25 deletions

19
api.go
View File

@ -274,9 +274,14 @@ func (a *Api) GetTunnel(tokenData TokenData, params url.Values) (Tunnel, error)
return Tunnel{}, errors.New("Invalid domain parameter") return Tunnel{}, errors.New("Invalid domain parameter")
} }
tun, exists := a.db.GetTunnel(domain) client := params.Get("client")
if client == "" {
return Tunnel{}, errors.New("Invalid client parameter")
}
tun, exists := a.db.GetTunnel(domain + "|" + client)
if !exists { if !exists {
return Tunnel{}, errors.New("Tunnel doesn't exist for domain") return Tunnel{}, errors.New("Tunnel doesn't exist for this domain|client combination")
} }
user, _ := a.db.GetUser(tokenData.Owner) user, _ := a.db.GetUser(tokenData.Owner)
@ -389,6 +394,7 @@ func (a *Api) CreateTunnel(tokenData TokenData, params url.Values) (*Tunnel, err
AuthUsername: username, AuthUsername: username,
AuthPassword: password, AuthPassword: password,
TlsTermination: tlsTerm, TlsTermination: tlsTerm,
Used: false,
} }
tunnel, err := a.tunMan.RequestCreateTunnel(request) tunnel, err := a.tunMan.RequestCreateTunnel(request)
@ -406,7 +412,12 @@ func (a *Api) DeleteTunnel(tokenData TokenData, params url.Values) error {
return errors.New("Invalid domain parameter") return errors.New("Invalid domain parameter")
} }
tun, exists := a.db.GetTunnel(domain) client := params.Get("client")
if client == "" {
return errors.New("Invalid client parameter")
}
tun, exists := a.db.GetTunnel(domain + "|" + client)
if !exists { if !exists {
return errors.New("Tunnel doesn't exist") return errors.New("Tunnel doesn't exist")
} }
@ -418,7 +429,7 @@ func (a *Api) DeleteTunnel(tokenData TokenData, params url.Values) error {
} }
} }
a.tunMan.DeleteTunnel(domain) a.tunMan.DeleteTunnel(domain + "|" + client)
return nil return nil
} }

View File

@ -281,7 +281,7 @@ func Listen() {
} }
} else { } else {
tunnel, exists := db.GetTunnel(hostDomain) tunnel, exists := db.SelectLoadBalancedTunnel(hostDomain)
if !exists { if !exists {
errMessage := fmt.Sprintf("No tunnel attached to %s", hostDomain) errMessage := fmt.Sprintf("No tunnel attached to %s", hostDomain)
w.WriteHeader(500) w.WriteHeader(500)
@ -342,7 +342,7 @@ func (p *Server) handleConnection(clientConn net.Conn) {
passConn := NewProxyConn(clientConn, clientReader) passConn := NewProxyConn(clientConn, clientReader)
tunnel, exists := p.db.GetTunnel(clientHello.ServerName) tunnel, exists := p.db.SelectLoadBalancedTunnel(clientHello.ServerName)
if exists && (tunnel.TlsTermination == "client" || tunnel.TlsTermination == "passthrough") || tunnel.TlsTermination == "client-tls" { if exists && (tunnel.TlsTermination == "client" || tunnel.TlsTermination == "passthrough") || tunnel.TlsTermination == "client-tls" {
p.passthroughRequest(passConn, tunnel) p.passthroughRequest(passConn, tunnel)

View File

@ -5,6 +5,7 @@ import (
"errors" "errors"
"io/ioutil" "io/ioutil"
"log" "log"
"strings"
"sync" "sync"
"github.com/takingnames/namedrop-go" "github.com/takingnames/namedrop-go"
@ -58,6 +59,7 @@ type Tunnel struct {
ClientName string `json:"client_name"` ClientName string `json:"client_name"`
AuthUsername string `json:"auth_username"` AuthUsername string `json:"auth_username"`
AuthPassword string `json:"auth_password"` AuthPassword string `json:"auth_password"`
Used bool
} }
func NewDatabase() (*Database, error) { func NewDatabase() (*Database, error) {
@ -223,6 +225,44 @@ func (d *Database) GetTunnels() map[string]Tunnel {
return tunnels return tunnels
} }
func (d *Database) SelectLoadBalancedTunnel(domain string) (Tunnel, bool) {
d.mutex.Lock()
defer d.mutex.Unlock()
tunnels := make(map[string]Tunnel)
for dom, tun := range d.Tunnels {
keys := strings.Split(dom, "|")
if domain == keys[0] {
tunnels[dom] = tun
}
}
// Load balance
for _, t := range tunnels {
if !t.Used {
t.Used = true
log.Printf("Routing to : %s|%s", t.Domain, t.ClientName)
return t, true
}
}
// Reset all used flags
count := 1
for _, t := range tunnels {
t.Used = false
if count == len(tunnels) {
t.Used = true
log.Printf("Routing to : %s|%s", t.Domain, t.ClientName)
return t, true
}
count += 1
}
// len(tunnels) is 0
return Tunnel{}, false
}
func (d *Database) GetTunnel(domain string) (Tunnel, bool) { func (d *Database) GetTunnel(domain string) (Tunnel, bool) {
d.mutex.Lock() d.mutex.Lock()
defer d.mutex.Unlock() defer d.mutex.Unlock()

View File

@ -30,8 +30,8 @@
</div> </div>
<div class='button-row'> <div class='button-row'>
<a class='button' href="/tunnel-private-key?domain={{$.Tunnel.Domain}}">Download Private Key</a> <a class='button' href="/tunnel-private-key?domain={{$.Tunnel.Domain}}&client={{$.Tunnel.ClientName}}">Download Private Key</a>
<a class='button' href="/confirm-delete-tunnel?domain={{$.Tunnel.Domain}}">Delete</a> <a class='button' href="/confirm-delete-tunnel?domain={{$.Tunnel.Domain}}&client={{$.Tunnel.ClientName}}">Delete</a>
</div> </div>
{{ template "footer.tmpl" . }} {{ template "footer.tmpl" . }}

View File

@ -4,7 +4,7 @@
<div class='tn-tunnel-list-item'> <div class='tn-tunnel-list-item'>
<div class='tn-attribute'> <div class='tn-attribute'>
<div class='tn-attribute__name'>Domain:</div> <div class='tn-attribute__name'>Domain:</div>
<div class='tn-attribute__value'><a href='https://{{$domain}}'>{{$domain}}</a></div> <div class='tn-attribute__value'><a href='https://{{$tunnel.Domain}}'>{{$tunnel.Domain}}</a></div>
</div> </div>
<div class='tn-attribute'> <div class='tn-attribute'>
<div class='tn-attribute__name'>Client:</div> <div class='tn-attribute__name'>Client:</div>
@ -15,8 +15,8 @@
<div class='tn-attribute__value'>{{$tunnel.ClientAddress}}:{{$tunnel.ClientPort}}</div> <div class='tn-attribute__value'>{{$tunnel.ClientAddress}}:{{$tunnel.ClientPort}}</div>
</div> </div>
<div class='button-row'> <div class='button-row'>
<a class='button' href="/tunnels/{{$domain}}">View</a> <a class='button' href="/tunnels/{{$tunnel.Domain}}/{{$tunnel.ClientName}}">View</a>
<a class='button' href="/confirm-delete-tunnel?domain={{$domain}}">Delete</a> <a class='button' href="/confirm-delete-tunnel?domain={{$tunnel.Domain}}&client={{$tunnel.ClientName}}">Delete</a>
</div> </div>
</div> </div>
{{ end }} {{ end }}
@ -36,14 +36,14 @@
{{range $domain, $tunnel:= .Tunnels}} {{range $domain, $tunnel:= .Tunnels}}
<tr> <tr>
<td class='tn-tunnel-table__cell'> <td class='tn-tunnel-table__cell'>
<a href='https://{{$domain}}'>{{$domain}}</a> <a href='https://{{$tunnel.Domain}}'>{{$tunnel.Domain}}</a>
</td> </td>
<td class='tn-tunnel-table__cell'>{{$tunnel.ClientName}}</td> <td class='tn-tunnel-table__cell'>{{$tunnel.ClientName}}</td>
<td class='tn-tunnel-table__cell'>{{$tunnel.ClientAddress}}:{{$tunnel.ClientPort}}</td> <td class='tn-tunnel-table__cell'>{{$tunnel.ClientAddress}}:{{$tunnel.ClientPort}}</td>
<td class='tn-tunnel-table__cell'> <td class='tn-tunnel-table__cell'>
<div class='button-row'> <div class='button-row'>
<a class='button' href="/tunnels/{{$domain}}">View</a> <a class='button' href="/tunnels/{{$tunnel.Domain}}/{{$tunnel.ClientName}}">View</a>
<a class='button' href="/confirm-delete-tunnel?domain={{$domain}}">Delete</a> <a class='button' href="/confirm-delete-tunnel?domain={{$tunnel.Domain}}&client={{$tunnel.ClientName}}">Delete</a>
</div> </div>
</td> </td>
</tr> </tr>

View File

@ -8,14 +8,15 @@ import (
"encoding/pem" "encoding/pem"
"errors" "errors"
"fmt" "fmt"
"github.com/caddyserver/certmagic"
"golang.org/x/crypto/ssh"
"io/ioutil" "io/ioutil"
"log" "log"
"os" "os"
"os/user" "os/user"
"strings" "strings"
"sync" "sync"
"github.com/caddyserver/certmagic"
"golang.org/x/crypto/ssh"
) )
type TunnelManager struct { type TunnelManager struct {
@ -34,9 +35,9 @@ func NewTunnelManager(config *Config, db *Database, certConfig *certmagic.Config
} }
if config.autoCerts { if config.autoCerts {
for domainName, tun := range db.GetTunnels() { for _, tun := range db.GetTunnels() {
if tun.TlsTermination == "server" { if tun.TlsTermination == "server" {
err = certConfig.ManageSync(context.Background(), []string{domainName}) err = certConfig.ManageSync(context.Background(), []string{tun.Domain})
if err != nil { if err != nil {
log.Println("CertMagic error at startup") log.Println("CertMagic error at startup")
log.Println(err) log.Println(err)
@ -84,8 +85,8 @@ func (m *TunnelManager) RequestCreateTunnel(tunReq Tunnel) (Tunnel, error) {
} }
for _, tun := range m.db.GetTunnels() { for _, tun := range m.db.GetTunnels() {
if tunReq.Domain == tun.Domain { if tunReq.Domain == tun.Domain && tunReq.ClientName == tun.ClientName {
return Tunnel{}, errors.New("Tunnel domain already in use") return Tunnel{}, errors.New("Tunnel domain and client name combination already in use")
} }
if tunReq.TunnelPort == tun.TunnelPort { if tunReq.TunnelPort == tun.TunnelPort {
@ -104,7 +105,7 @@ func (m *TunnelManager) RequestCreateTunnel(tunReq Tunnel) (Tunnel, error) {
tunReq.Username = m.user.Username tunReq.Username = m.user.Username
tunReq.TunnelPrivateKey = privKey tunReq.TunnelPrivateKey = privKey
m.db.SetTunnel(tunReq.Domain, tunReq) m.db.SetTunnel(tunReq.Domain+"|"+tunReq.ClientName, tunReq)
return tunReq, nil return tunReq, nil
} }

View File

@ -3,12 +3,15 @@ package boringproxy
import ( import (
"embed" "embed"
"encoding/base64" "encoding/base64"
//"encoding/json" //"encoding/json"
"fmt" "fmt"
qrcode "github.com/skip2/go-qrcode"
"html/template" "html/template"
"io" "io"
"net/http" "net/http"
qrcode "github.com/skip2/go-qrcode"
//"net/url" //"net/url"
//"os" //"os"
"strings" "strings"
@ -142,10 +145,17 @@ func (h *WebUiHandler) handleWebUiRequest(w http.ResponseWriter, r *http.Request
} }
domain := r.Form["domain"][0] domain := r.Form["domain"][0]
if len(r.Form["client"]) != 1 {
w.WriteHeader(400)
w.Write([]byte("Invalid client parameter"))
return
}
client := r.Form["client"][0]
data := &ConfirmData{ data := &ConfirmData{
Head: h.headHtml, Head: h.headHtml,
Message: fmt.Sprintf("Are you sure you want to delete %s?", domain), Message: fmt.Sprintf("Are you sure you want to delete %s|%s?", domain, client),
ConfirmUrl: fmt.Sprintf("/delete-tunnel?domain=%s", domain), ConfirmUrl: fmt.Sprintf("/delete-tunnel?domain=%s&client=%s", domain, client),
CancelUrl: "/tunnels", CancelUrl: "/tunnels",
} }
@ -291,15 +301,17 @@ func (h *WebUiHandler) handleWebUiRequest(w http.ResponseWriter, r *http.Request
parts := strings.Split(r.URL.Path, "/") parts := strings.Split(r.URL.Path, "/")
if len(parts) != 3 { if len(parts) != 4 {
w.WriteHeader(400) w.WriteHeader(400)
h.alertDialog(w, r, "Invalid path", "/tunnels") h.alertDialog(w, r, "Invalid path", "/tunnels")
return return
} }
domain := parts[2] domain := parts[2]
client := parts[3]
r.Form.Set("domain", domain) r.Form.Set("domain", domain)
r.Form.Set("client", client)
tunnel, err := h.api.GetTunnel(tokenData, r.Form) tunnel, err := h.api.GetTunnel(tokenData, r.Form)
if err != nil { if err != nil {