Automatically create admin user on first start

Also changed order that extractToken looks for tokens. Used to
be cookies then headers then query. Now in reverse, to make it
easier to override, ie for replacing cookies during login.
This commit is contained in:
Anders Pitman 2020-10-13 09:48:03 -06:00
parent 8d6e4c2fe8
commit 5cd911f310
5 changed files with 108 additions and 80 deletions

16
auth.go
View File

@ -1,8 +1,6 @@
package main package main
import ( import (
"crypto/rand"
"math/big"
"sync" "sync"
) )
@ -33,17 +31,3 @@ func (a *Auth) Authorized(token string) bool {
return false return false
} }
const chars string = "0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
func genRandomKey() (string, error) {
id := ""
for i := 0; i < 32; i++ {
randIndex, err := rand.Int(rand.Reader, big.NewInt(int64(len(chars))))
if err != nil {
return "", err
}
id += string(chars[randIndex.Int64()])
}
return id, nil
}

View File

@ -48,6 +48,17 @@ func Listen() {
log.Fatal(err) log.Fatal(err)
} }
users := db.GetUsers()
if len(users) == 0 {
db.AddUser("admin", true)
token, err := db.AddToken("admin")
if err != nil {
log.Fatal("Failed to initialize admin user")
}
log.Println("Admin token: " + token)
}
certmagic.DefaultACME.DisableHTTPChallenge = true certmagic.DefaultACME.DisableHTTPChallenge = true
//certmagic.DefaultACME.DisableTLSALPNChallenge = true //certmagic.DefaultACME.DisableTLSALPNChallenge = true
//certmagic.DefaultACME.CA = certmagic.LetsEncryptStagingCA //certmagic.DefaultACME.CA = certmagic.LetsEncryptStagingCA
@ -103,8 +114,6 @@ func Listen() {
func (p *BoringProxy) proxyRequest(w http.ResponseWriter, r *http.Request) { func (p *BoringProxy) proxyRequest(w http.ResponseWriter, r *http.Request) {
log.Println("proxy conn")
port, err := p.tunMan.GetPort(r.Host) port, err := p.tunMan.GetPort(r.Host)
if err != nil { if err != nil {
log.Print(err) log.Print(err)

View File

@ -16,7 +16,7 @@ type Database struct {
} }
type TokenData struct { type TokenData struct {
Id string `json:"id"` Owner string `json:"owner"`
} }
type User struct { type User struct {
@ -71,6 +71,27 @@ func NewDatabase() (*Database, error) {
return db, nil return db, nil
} }
func (d *Database) AddToken(owner string) (string, error) {
d.mutex.Lock()
defer d.mutex.Unlock()
_, exists := d.Users[owner]
if !exists {
return "", errors.New("Owner doesn't exist")
}
token, err := genRandomCode(32)
if err != nil {
return "", errors.New("Could not generat token")
}
d.Tokens[token] = TokenData{owner}
d.persist()
return token, nil
}
func (d *Database) GetTokenData(token string) (TokenData, bool) { func (d *Database) GetTokenData(token string) (TokenData, bool) {
d.mutex.Lock() d.mutex.Lock()
defer d.mutex.Unlock() defer d.mutex.Unlock()

View File

@ -34,8 +34,8 @@ type ConfirmData struct {
} }
type AlertData struct { type AlertData struct {
Head template.HTML Head template.HTML
Message string Message string
RedirectUrl string RedirectUrl string
} }
@ -116,9 +116,9 @@ func (h *WebUiHandler) handleWebUiRequest(w http.ResponseWriter, r *http.Request
case "/users": case "/users":
h.users(w, r) h.users(w, r)
case "/confirm-delete-user": case "/confirm-delete-user":
h.confirmDeleteUser(w, r) h.confirmDeleteUser(w, r)
case "/delete-user": case "/delete-user":
h.deleteUser(w, r) h.deleteUser(w, r)
case "/": case "/":
indexTemplate, err := box.String("index.tmpl") indexTemplate, err := box.String("index.tmpl")
@ -158,12 +158,12 @@ func (h *WebUiHandler) handleWebUiRequest(w http.ResponseWriter, r *http.Request
} }
domain := r.Form["domain"][0] domain := r.Form["domain"][0]
tmpl, err := h.loadTemplate("confirm.tmpl") tmpl, err := h.loadTemplate("confirm.tmpl")
if err != nil { if err != nil {
w.WriteHeader(500) w.WriteHeader(500)
io.WriteString(w, err.Error()) io.WriteString(w, err.Error())
return return
} }
data := &ConfirmData{ data := &ConfirmData{
Head: h.headHtml, Head: h.headHtml,
@ -316,20 +316,20 @@ func (h *WebUiHandler) users(w http.ResponseWriter, r *http.Request) {
} }
username := r.Form["username"][0] username := r.Form["username"][0]
minUsernameLen := 6 minUsernameLen := 6
if len(username) < minUsernameLen { if len(username) < minUsernameLen {
w.WriteHeader(400) w.WriteHeader(400)
errStr := fmt.Sprintf("Username must be at least %d characters", minUsernameLen) errStr := fmt.Sprintf("Username must be at least %d characters", minUsernameLen)
h.alertDialog(w, r, errStr, "/users") h.alertDialog(w, r, errStr, "/users")
return return
} }
isAdmin := len(r.Form["is-admin"]) == 1 && r.Form["is-admin"][0] == "on" isAdmin := len(r.Form["is-admin"]) == 1 && r.Form["is-admin"][0] == "on"
err := h.db.AddUser(username, isAdmin) err := h.db.AddUser(username, isAdmin)
if err != nil { if err != nil {
w.WriteHeader(500) w.WriteHeader(500)
h.alertDialog(w, r, err.Error(), "/users") h.alertDialog(w, r, err.Error(), "/users")
return return
} }
@ -351,21 +351,21 @@ func (h *WebUiHandler) users(w http.ResponseWriter, r *http.Request) {
func (h *WebUiHandler) confirmDeleteUser(w http.ResponseWriter, r *http.Request) { func (h *WebUiHandler) confirmDeleteUser(w http.ResponseWriter, r *http.Request) {
r.ParseForm() r.ParseForm()
if len(r.Form["username"]) != 1 { if len(r.Form["username"]) != 1 {
w.WriteHeader(400) w.WriteHeader(400)
w.Write([]byte("Invalid username parameter")) w.Write([]byte("Invalid username parameter"))
return return
} }
username := r.Form["username"][0] username := r.Form["username"][0]
tmpl, err := h.loadTemplate("confirm.tmpl") tmpl, err := h.loadTemplate("confirm.tmpl")
if err != nil { if err != nil {
w.WriteHeader(500) w.WriteHeader(500)
io.WriteString(w, err.Error()) io.WriteString(w, err.Error())
return return
} }
data := &ConfirmData{ data := &ConfirmData{
Head: h.headHtml, Head: h.headHtml,
@ -379,33 +379,33 @@ func (h *WebUiHandler) confirmDeleteUser(w http.ResponseWriter, r *http.Request)
func (h *WebUiHandler) deleteUser(w http.ResponseWriter, r *http.Request) { func (h *WebUiHandler) deleteUser(w http.ResponseWriter, r *http.Request) {
r.ParseForm() r.ParseForm()
if len(r.Form["username"]) != 1 { if len(r.Form["username"]) != 1 {
w.WriteHeader(400) w.WriteHeader(400)
w.Write([]byte("Invalid username parameter")) w.Write([]byte("Invalid username parameter"))
return return
} }
username := r.Form["username"][0] username := r.Form["username"][0]
h.db.DeleteUser(username) h.db.DeleteUser(username)
http.Redirect(w, r, "/users", 303) http.Redirect(w, r, "/users", 303)
} }
func (h *WebUiHandler) alertDialog(w http.ResponseWriter, r *http.Request, message, redirectUrl string) error { func (h *WebUiHandler) alertDialog(w http.ResponseWriter, r *http.Request, message, redirectUrl string) error {
tmpl, err := h.loadTemplate("alert.tmpl") tmpl, err := h.loadTemplate("alert.tmpl")
if err != nil { if err != nil {
return err return err
} }
tmpl.Execute(w, &AlertData{ tmpl.Execute(w, &AlertData{
Head: h.headHtml, Head: h.headHtml,
Message: message, Message: message,
RedirectUrl: redirectUrl, RedirectUrl: redirectUrl,
}) })
return nil return nil
} }
func (h *WebUiHandler) loadTemplate(name string) (*template.Template, error) { func (h *WebUiHandler) loadTemplate(name string) (*template.Template, error) {

View File

@ -1,9 +1,11 @@
package main package main
import ( import (
"crypto/rand"
"encoding/json" "encoding/json"
"errors" "errors"
"io/ioutil" "io/ioutil"
"math/big"
"net/http" "net/http"
"strings" "strings"
) )
@ -21,26 +23,8 @@ func saveJson(data interface{}, filePath string) error {
return nil return nil
} }
// Looks for auth token in cookie, then header, then query string // Looks for auth token in query string, then headers, then cookies
func extractToken(tokenName string, r *http.Request) (string, error) { func extractToken(tokenName string, r *http.Request) (string, error) {
tokenCookie, err := r.Cookie(tokenName)
if err == nil {
return tokenCookie.Value, nil
}
tokenHeader := r.Header.Get(tokenName)
if tokenHeader != "" {
return tokenHeader, nil
}
authHeader := r.Header.Get("Authorization")
if authHeader != "" {
tokenHeader := strings.Split(authHeader, " ")[1]
return tokenHeader, nil
}
query := r.URL.Query() query := r.URL.Query()
@ -49,5 +33,35 @@ func extractToken(tokenName string, r *http.Request) (string, error) {
return queryToken, nil return queryToken, nil
} }
tokenHeader := r.Header.Get(tokenName)
if tokenHeader != "" {
return tokenHeader, nil
}
authHeader := r.Header.Get("Authorization")
if authHeader != "" {
tokenHeader := strings.Split(authHeader, " ")[1]
return tokenHeader, nil
}
tokenCookie, err := r.Cookie(tokenName)
if err == nil {
return tokenCookie.Value, nil
}
return "", errors.New("No token found") return "", errors.New("No token found")
} }
const chars string = "0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
func genRandomCode(length int) (string, error) {
id := ""
for i := 0; i < length; i++ {
randIndex, err := rand.Int(rand.Reader, big.NewInt(int64(len(chars))))
if err != nil {
return "", err
}
id += string(chars[randIndex.Int64()])
}
return id, nil
}