diff --git a/cmd/boringproxy/main.go b/cmd/boringproxy/main.go index 4e28b61..5331dc1 100644 --- a/cmd/boringproxy/main.go +++ b/cmd/boringproxy/main.go @@ -2,9 +2,13 @@ package main import ( "context" + "crypto/tls" "flag" "fmt" + "io" + "net" "os" + "sync" "github.com/boringproxy/boringproxy" ) @@ -15,6 +19,7 @@ Commands: version Prints version information. server Start a new server. client Connect to a server. + tuntls Tunnel a raw TLS connection. Use "%[1]s command -h" for a list of flags for the command. ` @@ -25,7 +30,6 @@ func fail(msg string) { fmt.Fprintln(os.Stderr, msg) os.Exit(1) } - func main() { if len(os.Args) < 2 { fmt.Fprintln(os.Stderr, os.Args[0]+": Need a command") @@ -40,6 +44,44 @@ func main() { fmt.Println(Version) case "help", "-h", "--help", "-help": fmt.Printf(usage, os.Args[0]) + case "tuntls": + // This command is a direct port of https://github.com/anderspitman/tuntls + flagSet := flag.NewFlagSet(os.Args[0], flag.ExitOnError) + server := flagSet.String("server", "", "boringproxy server") + port := flagSet.Int("port", 0, "Local port to bind to") + err := flagSet.Parse(os.Args[2:]) + if err != nil { + fmt.Fprintf(os.Stderr, "%s: parsing flags: %s\n", os.Args[0], err) + os.Exit(1) + } + + if *server == "" { + fmt.Fprintf(os.Stderr, "server argument is required\n") + os.Exit(1) + } + + if *port == 0 { + // one-time tunnel over stdin/stdout + doTlsTunnel(*server, os.Stdin, os.Stdout) + } else { + // listen on a port and create tunnels for each connection + fmt.Fprintf(os.Stderr, "Listening on port %d\n", *port) + listener, err := net.Listen("tcp", fmt.Sprintf(":%d", *port)) + if err != nil { + fmt.Fprintf(os.Stderr, err.Error()) + os.Exit(1) + } + + for { + conn, err := listener.Accept() + if err != nil { + fmt.Fprintf(os.Stderr, err.Error()) + os.Exit(1) + } + + go doTlsTunnel(*server, conn, conn) + } + } case "server": boringproxy.Listen() case "client": @@ -97,3 +139,32 @@ func main() { fail(os.Args[0] + ": Invalid command " + command) } } + +func doTlsTunnel(server string, in io.Reader, out io.Writer) { + fmt.Fprintf(os.Stderr, "tuntls connecting to server: %s\n", server) + + conn, err := tls.Dial("tcp", fmt.Sprintf("%s:443", server), &tls.Config{ + //RootCAs: roots, + }) + if err != nil { + fmt.Fprintf(os.Stderr, "Failed to connect: "+err.Error()) + os.Exit(1) + } + + var wg sync.WaitGroup + wg.Add(2) + + go func() { + io.Copy(conn, in) + wg.Done() + }() + + go func() { + io.Copy(out, conn) + wg.Done() + }() + + wg.Wait() + + conn.Close() +}