mirror of
https://github.com/boringproxy/boringproxy.git
synced 2025-02-25 18:55:29 -06:00
Implement client TLS termination
Managed to reuse the same proxy function the server uses.
This commit is contained in:
parent
14a666481a
commit
560d682a31
@ -2,13 +2,11 @@ package main
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"crypto/tls"
|
||||
"flag"
|
||||
"fmt"
|
||||
"github.com/caddyserver/certmagic"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"log"
|
||||
"net"
|
||||
"net/http"
|
||||
@ -122,7 +120,16 @@ func Listen() {
|
||||
webUiHandler.handleWebUiRequest(w, r)
|
||||
}
|
||||
} else {
|
||||
p.proxyRequest(w, r)
|
||||
|
||||
tunnel, exists := db.GetTunnel(r.Host)
|
||||
if !exists {
|
||||
errMessage := fmt.Sprintf("No tunnel attached to %s", r.Host)
|
||||
w.WriteHeader(500)
|
||||
io.WriteString(w, errMessage)
|
||||
return
|
||||
}
|
||||
|
||||
proxyRequest(w, r, tunnel, httpClient, tunnel.TunnelPort)
|
||||
}
|
||||
})
|
||||
|
||||
@ -163,7 +170,7 @@ func (p *BoringProxy) handleConnection(clientConn net.Conn) {
|
||||
|
||||
tunnel, exists := p.db.GetTunnel(clientHello.ServerName)
|
||||
|
||||
if exists && tunnel.TlsPassthrough {
|
||||
if exists && (tunnel.TlsTermination == "client" || tunnel.TlsTermination == "passthrough") {
|
||||
p.passthroughRequest(passConn, tunnel)
|
||||
} else {
|
||||
p.httpListener.PassConn(passConn)
|
||||
@ -198,78 +205,6 @@ func (p *BoringProxy) passthroughRequest(conn net.Conn, tunnel Tunnel) {
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func (p *BoringProxy) proxyRequest(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
tunnel, exists := p.db.GetTunnel(r.Host)
|
||||
if !exists {
|
||||
errMessage := fmt.Sprintf("No tunnel attached to %s", r.Host)
|
||||
w.WriteHeader(500)
|
||||
io.WriteString(w, errMessage)
|
||||
return
|
||||
}
|
||||
|
||||
if tunnel.AuthUsername != "" || tunnel.AuthPassword != "" {
|
||||
username, password, ok := r.BasicAuth()
|
||||
if !ok {
|
||||
w.Header()["WWW-Authenticate"] = []string{"Basic"}
|
||||
w.WriteHeader(401)
|
||||
return
|
||||
}
|
||||
|
||||
if username != tunnel.AuthUsername || password != tunnel.AuthPassword {
|
||||
w.Header()["WWW-Authenticate"] = []string{"Basic"}
|
||||
w.WriteHeader(401)
|
||||
// TODO: should probably use a better form of rate limiting
|
||||
time.Sleep(2 * time.Second)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
downstreamReqHeaders := r.Header.Clone()
|
||||
|
||||
upstreamAddr := fmt.Sprintf("localhost:%d", tunnel.TunnelPort)
|
||||
upstreamUrl := fmt.Sprintf("http://%s%s", upstreamAddr, r.URL.RequestURI())
|
||||
|
||||
body, err := ioutil.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
errMessage := fmt.Sprintf("%s", err)
|
||||
w.WriteHeader(500)
|
||||
io.WriteString(w, errMessage)
|
||||
return
|
||||
}
|
||||
|
||||
upstreamReq, err := http.NewRequest(r.Method, upstreamUrl, bytes.NewReader(body))
|
||||
if err != nil {
|
||||
errMessage := fmt.Sprintf("%s", err)
|
||||
w.WriteHeader(500)
|
||||
io.WriteString(w, errMessage)
|
||||
return
|
||||
}
|
||||
|
||||
upstreamReq.Header = downstreamReqHeaders
|
||||
|
||||
upstreamReq.Header["X-Forwarded-Host"] = []string{r.Host}
|
||||
upstreamReq.Host = fmt.Sprintf("%s:%d", tunnel.ClientAddress, tunnel.ClientPort)
|
||||
|
||||
upstreamRes, err := p.httpClient.Do(upstreamReq)
|
||||
if err != nil {
|
||||
errMessage := fmt.Sprintf("%s", err)
|
||||
w.WriteHeader(502)
|
||||
io.WriteString(w, errMessage)
|
||||
return
|
||||
}
|
||||
defer upstreamRes.Body.Close()
|
||||
|
||||
downstreamResHeaders := w.Header()
|
||||
|
||||
for k, v := range upstreamRes.Header {
|
||||
downstreamResHeaders[k] = v
|
||||
}
|
||||
|
||||
w.WriteHeader(upstreamRes.StatusCode)
|
||||
io.Copy(w, upstreamRes.Body)
|
||||
}
|
||||
|
||||
func redirectTLS(w http.ResponseWriter, r *http.Request) {
|
||||
url := fmt.Sprintf("https://%s:443%s", r.Host, r.RequestURI)
|
||||
http.Redirect(w, r, url, http.StatusMovedPermanently)
|
||||
|
34
client.go
34
client.go
@ -7,6 +7,7 @@ import (
|
||||
"errors"
|
||||
"flag"
|
||||
"fmt"
|
||||
"github.com/caddyserver/certmagic"
|
||||
"golang.org/x/crypto/ssh"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
@ -29,6 +30,7 @@ type BoringProxyClient struct {
|
||||
user string
|
||||
cancelFuncs map[string]context.CancelFunc
|
||||
cancelFuncsMutex *sync.Mutex
|
||||
certConfig *certmagic.Config
|
||||
}
|
||||
|
||||
func NewBoringProxyClient() *BoringProxyClient {
|
||||
@ -39,6 +41,9 @@ func NewBoringProxyClient() *BoringProxyClient {
|
||||
user := flagSet.String("user", "admin", "user")
|
||||
flagSet.Parse(os.Args[2:])
|
||||
|
||||
certmagic.DefaultACME.DisableHTTPChallenge = true
|
||||
certConfig := certmagic.NewDefault()
|
||||
|
||||
httpClient := &http.Client{}
|
||||
tunnels := make(map[string]Tunnel)
|
||||
cancelFuncs := make(map[string]context.CancelFunc)
|
||||
@ -54,6 +59,7 @@ func NewBoringProxyClient() *BoringProxyClient {
|
||||
user: *user,
|
||||
cancelFuncs: cancelFuncs,
|
||||
cancelFuncsMutex: cancelFuncsMutex,
|
||||
certConfig: certConfig,
|
||||
}
|
||||
}
|
||||
|
||||
@ -218,6 +224,33 @@ func (c *BoringProxyClient) BoreTunnel(tunnel Tunnel) context.CancelFunc {
|
||||
}
|
||||
//defer listener.Close()
|
||||
|
||||
if tunnel.TlsTermination == "client" {
|
||||
// TODO: There's still quite a bit of duplication with what the server does. Could we
|
||||
// encapsulate it into a type?
|
||||
err = c.certConfig.ManageSync([]string{tunnel.Domain})
|
||||
if err != nil {
|
||||
log.Println("CertMagic error at startup")
|
||||
log.Println(err)
|
||||
}
|
||||
|
||||
tlsConfig := &tls.Config{
|
||||
GetCertificate: c.certConfig.GetCertificate,
|
||||
NextProtos: []string{"h2", "acme-tls/1"},
|
||||
}
|
||||
tlsListener := tls.NewListener(listener, tlsConfig)
|
||||
|
||||
http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
|
||||
proxyRequest(w, r, tunnel, c.httpClient, tunnel.ClientPort)
|
||||
})
|
||||
|
||||
// TODO: It seems inefficient to make a separate HTTP server for each TLS-passthrough tunnel,
|
||||
// but the code is much simpler. The only alternative I've thought of so far involves storing
|
||||
// all the tunnels in a mutexed map and retrieving them from a single HTTP server, same as the
|
||||
// boringproxy server does.
|
||||
go http.Serve(tlsListener, nil)
|
||||
|
||||
} else {
|
||||
|
||||
go func() {
|
||||
for {
|
||||
conn, err := listener.Accept()
|
||||
@ -233,6 +266,7 @@ func (c *BoringProxyClient) BoreTunnel(tunnel Tunnel) context.CancelFunc {
|
||||
go c.handleConnection(conn, tunnel.ClientAddress, tunnel.ClientPort)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
<-ctx.Done()
|
||||
listener.Close()
|
||||
|
@ -50,7 +50,7 @@ type Tunnel struct {
|
||||
AuthUsername string `json:"auth_username"`
|
||||
AuthPassword string `json:"auth_password"`
|
||||
CssId string `json:"css_id"`
|
||||
TlsPassthrough bool `json:"tls_passthrough"`
|
||||
TlsTermination string `json:"tls_termination"`
|
||||
}
|
||||
|
||||
func NewDatabase() (*Database, error) {
|
||||
|
74
http_proxy.go
Normal file
74
http_proxy.go
Normal file
@ -0,0 +1,74 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
|
||||
func proxyRequest(w http.ResponseWriter, r *http.Request, tunnel Tunnel, httpClient *http.Client, port int) {
|
||||
|
||||
if tunnel.AuthUsername != "" || tunnel.AuthPassword != "" {
|
||||
username, password, ok := r.BasicAuth()
|
||||
if !ok {
|
||||
w.Header()["WWW-Authenticate"] = []string{"Basic"}
|
||||
w.WriteHeader(401)
|
||||
return
|
||||
}
|
||||
|
||||
if username != tunnel.AuthUsername || password != tunnel.AuthPassword {
|
||||
w.Header()["WWW-Authenticate"] = []string{"Basic"}
|
||||
w.WriteHeader(401)
|
||||
// TODO: should probably use a better form of rate limiting
|
||||
time.Sleep(2 * time.Second)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
downstreamReqHeaders := r.Header.Clone()
|
||||
|
||||
upstreamAddr := fmt.Sprintf("localhost:%d", port)
|
||||
upstreamUrl := fmt.Sprintf("http://%s%s", upstreamAddr, r.URL.RequestURI())
|
||||
|
||||
body, err := ioutil.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
errMessage := fmt.Sprintf("%s", err)
|
||||
w.WriteHeader(500)
|
||||
io.WriteString(w, errMessage)
|
||||
return
|
||||
}
|
||||
|
||||
upstreamReq, err := http.NewRequest(r.Method, upstreamUrl, bytes.NewReader(body))
|
||||
if err != nil {
|
||||
errMessage := fmt.Sprintf("%s", err)
|
||||
w.WriteHeader(500)
|
||||
io.WriteString(w, errMessage)
|
||||
return
|
||||
}
|
||||
|
||||
upstreamReq.Header = downstreamReqHeaders
|
||||
|
||||
upstreamReq.Header["X-Forwarded-Host"] = []string{r.Host}
|
||||
upstreamReq.Host = fmt.Sprintf("%s:%d", tunnel.ClientAddress, tunnel.ClientPort)
|
||||
|
||||
upstreamRes, err := httpClient.Do(upstreamReq)
|
||||
if err != nil {
|
||||
errMessage := fmt.Sprintf("%s", err)
|
||||
w.WriteHeader(502)
|
||||
io.WriteString(w, errMessage)
|
||||
return
|
||||
}
|
||||
defer upstreamRes.Body.Close()
|
||||
|
||||
downstreamResHeaders := w.Header()
|
||||
|
||||
for k, v := range upstreamRes.Header {
|
||||
downstreamResHeaders[k] = v
|
||||
}
|
||||
|
||||
w.WriteHeader(upstreamRes.StatusCode)
|
||||
io.Copy(w, upstreamRes.Body)
|
||||
}
|
6
sni.go
6
sni.go
@ -1,5 +1,4 @@
|
||||
// NOTE: The code in this file was mostly copied from this very helpful
|
||||
// article:
|
||||
// NOTE: A lot of this code was copied from this very helpful article:
|
||||
// https://www.agwa.name/blog/post/writing_an_sni_proxy_in_go
|
||||
|
||||
package main
|
||||
@ -92,7 +91,8 @@ func (c ProxyConn) CloseWrite() error { return c.conn.(*net.TCPConn).C
|
||||
func (c ProxyConn) Read(p []byte) (int, error) { return c.reader.Read(p) }
|
||||
func (c ProxyConn) Write(p []byte) (int, error) { return c.conn.Write(p) }
|
||||
|
||||
// TODO: is this safe? Will it actually close properly?
|
||||
// TODO: is this safe? Will it actually close properly, or does it need to be
|
||||
// connected to the reader somehow?
|
||||
func (c ProxyConn) Close() error { return c.conn.Close() }
|
||||
func (c ProxyConn) LocalAddr() net.Addr { return c.conn.LocalAddr() }
|
||||
func (c ProxyConn) RemoteAddr() net.Addr { return c.conn.RemoteAddr() }
|
||||
|
Loading…
Reference in New Issue
Block a user