mirror of
https://github.com/boringproxy/boringproxy.git
synced 2025-02-25 18:55:29 -06:00
Previously we were reading the entire downstream request into memory before making the new request to the upstream. Now we're just passing it through. Might be some dragons here (already ran into issues with Content-Length) but seems to be working so far.
72 lines
1.8 KiB
Go
72 lines
1.8 KiB
Go
package boringproxy
|
|
|
|
import (
	"crypto/subtle"
	"fmt"
	"io"
	"net/http"
	"time"
)
|
|
|
|
func proxyRequest(w http.ResponseWriter, r *http.Request, tunnel Tunnel, httpClient *http.Client, port int) {
|
|
|
|
if tunnel.AuthUsername != "" || tunnel.AuthPassword != "" {
|
|
username, password, ok := r.BasicAuth()
|
|
if !ok {
|
|
w.Header()["WWW-Authenticate"] = []string{"Basic"}
|
|
w.WriteHeader(401)
|
|
return
|
|
}
|
|
|
|
if username != tunnel.AuthUsername || password != tunnel.AuthPassword {
|
|
w.Header()["WWW-Authenticate"] = []string{"Basic"}
|
|
w.WriteHeader(401)
|
|
// TODO: should probably use a better form of rate limiting
|
|
time.Sleep(2 * time.Second)
|
|
return
|
|
}
|
|
}
|
|
|
|
downstreamReqHeaders := r.Header.Clone()
|
|
|
|
// TODO: should probably pass in address instead of using localhost,
|
|
// mostly for client-terminated TLS
|
|
upstreamAddr := fmt.Sprintf("localhost:%d", port)
|
|
upstreamUrl := fmt.Sprintf("http://%s%s", upstreamAddr, r.URL.RequestURI())
|
|
|
|
upstreamReq, err := http.NewRequest(r.Method, upstreamUrl, r.Body)
|
|
if err != nil {
|
|
errMessage := fmt.Sprintf("%s", err)
|
|
w.WriteHeader(500)
|
|
io.WriteString(w, errMessage)
|
|
return
|
|
}
|
|
|
|
// ContentLength needs to be set manually because otherwise it is
|
|
// stripped by golang. See:
|
|
// https://golang.org/pkg/net/http/#Request.Write
|
|
upstreamReq.ContentLength = r.ContentLength
|
|
|
|
upstreamReq.Header = downstreamReqHeaders
|
|
|
|
upstreamReq.Header["X-Forwarded-Host"] = []string{r.Host}
|
|
upstreamReq.Host = fmt.Sprintf("%s:%d", tunnel.ClientAddress, tunnel.ClientPort)
|
|
|
|
upstreamRes, err := httpClient.Do(upstreamReq)
|
|
if err != nil {
|
|
errMessage := fmt.Sprintf("%s", err)
|
|
w.WriteHeader(502)
|
|
io.WriteString(w, errMessage)
|
|
return
|
|
}
|
|
defer upstreamRes.Body.Close()
|
|
|
|
downstreamResHeaders := w.Header()
|
|
|
|
for k, v := range upstreamRes.Header {
|
|
downstreamResHeaders[k] = v
|
|
}
|
|
|
|
w.WriteHeader(upstreamRes.StatusCode)
|
|
io.Copy(w, upstreamRes.Body)
|
|
}
|