Implement binding non-default ports

Can now bind to ports other than 80/443, using the -http-port and
-https-port arguments.

Assuming you already have the certs you need, HTTPS even works.

Unfortunately you can't get the certs automatically because
LetsEncrypt doesn't support ports other than 80/443 as far as I
know.
This commit is contained in:
Anders Pitman 2021-12-20 12:56:50 -07:00
parent cf281fa7f2
commit 30358d7808

View File

@ -67,15 +67,21 @@ func checkPublicAddress(host string, port int) error {
} }
}() }()
conn, err := net.Dial("tcp", fmt.Sprintf("%s:%d", host, port)) addr := fmt.Sprintf("%s:%d", host, port)
conn, err := net.DialTimeout("tcp", addr, time.Second)
if err != nil { if err != nil {
return nil return err
} }
defer conn.Close() defer conn.Close()
go func() {
time.Sleep(time.Second)
conn.Close()
}()
data, err := io.ReadAll(conn) data, err := io.ReadAll(conn)
if err != nil { if err != nil {
return nil return errors.New(fmt.Sprintf("Error connecting to public address %s. Probably timed out", addr))
} }
retCode := string(data) retCode := string(data)
@ -110,6 +116,8 @@ func Listen() {
sshServerPort := flagSet.Int("ssh-server-port", 22, "SSH Server Port") sshServerPort := flagSet.Int("ssh-server-port", 22, "SSH Server Port")
certDir := flagSet.String("cert-dir", "", "TLS cert directory") certDir := flagSet.String("cert-dir", "", "TLS cert directory")
printLogin := flagSet.Bool("print-login", false, "Prints admin login information") printLogin := flagSet.Bool("print-login", false, "Prints admin login information")
httpPort := flagSet.Int("http-port", 80, "HTTP (insecure) port")
httpsPort := flagSet.Int("https-port", 443, "HTTPS (secure) port")
err := flagSet.Parse(os.Args[2:]) err := flagSet.Parse(os.Args[2:])
if err != nil { if err != nil {
fmt.Fprintf(os.Stderr, "%s: parsing flags: %s\n", os.Args[0], err) fmt.Fprintf(os.Stderr, "%s: parsing flags: %s\n", os.Args[0], err)
@ -122,12 +130,12 @@ func Listen() {
log.Fatal(err) log.Fatal(err)
} }
err = checkPublicAddress(ip, 80) err = checkPublicAddress(ip, *httpPort)
if err != nil { if err != nil {
log.Fatal(err) log.Fatal(err)
} }
err = checkPublicAddress(ip, 443) err = checkPublicAddress(ip, *httpsPort)
if err != nil { if err != nil {
log.Fatal(err) log.Fatal(err)
} }
@ -221,6 +229,11 @@ func Listen() {
timestamp := time.Now().Format(time.RFC3339) timestamp := time.Now().Format(time.RFC3339)
srcIp := strings.Split(r.RemoteAddr, ":")[0] srcIp := strings.Split(r.RemoteAddr, ":")[0]
fmt.Println(fmt.Sprintf("%s %s %s %s %s", timestamp, srcIp, r.Method, r.Host, r.URL.Path)) fmt.Println(fmt.Sprintf("%s %s %s %s %s", timestamp, srcIp, r.Method, r.Host, r.URL.Path))
// TODO: handle ipv6
hostParts := strings.Split(r.Host, ":")
hostDomain := hostParts[0]
if r.URL.Path == "/dnsapi/requests" { if r.URL.Path == "/dnsapi/requests" {
r.ParseForm() r.ParseForm()
@ -274,7 +287,7 @@ func Listen() {
http.Redirect(w, r, fmt.Sprintf("https://%s/edit-tunnel?domain=%s", adminDomain, domain), 303) http.Redirect(w, r, fmt.Sprintf("https://%s/edit-tunnel?domain=%s", adminDomain, domain), 303)
} }
} else if r.Host == db.GetAdminDomain() { } else if hostDomain == db.GetAdminDomain() {
if strings.HasPrefix(r.URL.Path, "/api/") { if strings.HasPrefix(r.URL.Path, "/api/") {
http.StripPrefix("/api", api).ServeHTTP(w, r) http.StripPrefix("/api", api).ServeHTTP(w, r)
} else { } else {
@ -282,9 +295,9 @@ func Listen() {
} }
} else { } else {
tunnel, exists := db.GetTunnel(r.Host) tunnel, exists := db.GetTunnel(hostDomain)
if !exists { if !exists {
errMessage := fmt.Sprintf("No tunnel attached to %s", r.Host) errMessage := fmt.Sprintf("No tunnel attached to %s", hostDomain)
w.WriteHeader(500) w.WriteHeader(500)
io.WriteString(w, errMessage) io.WriteString(w, errMessage)
return return
@ -295,14 +308,14 @@ func Listen() {
}) })
go func() { go func() {
if err := http.ListenAndServe(":80", nil); err != nil { if err := http.ListenAndServe(fmt.Sprintf(":%d", *httpPort), nil); err != nil {
log.Fatalf("ListenAndServe error: %v", err) log.Fatalf("ListenAndServe error: %v", err)
} }
}() }()
go http.Serve(tlsListener, nil) go http.Serve(tlsListener, nil)
listener, err := net.Listen("tcp", ":443") listener, err := net.Listen("tcp", fmt.Sprintf(":%d", *httpsPort))
if err != nil { if err != nil {
log.Fatal(err) log.Fatal(err)
} }