diff --git a/svchost/disco/disco.go b/svchost/disco/disco.go new file mode 100644 index 0000000000..f3622bbfc6 --- /dev/null +++ b/svchost/disco/disco.go @@ -0,0 +1,177 @@ +// Package disco handles Terraform's remote service discovery protocol. +// +// This protocol allows mapping from a service hostname, as produced by the +// svchost package, to a set of services supported by that host and the +// endpoint information for each supported service. +package disco + +import ( + "encoding/json" + "errors" + "fmt" + "io" + "io/ioutil" + "log" + "mime" + "net/http" + "net/url" + "time" + + cleanhttp "github.com/hashicorp/go-cleanhttp" + "github.com/hashicorp/terraform/svchost" + "github.com/hashicorp/terraform/terraform" +) + +const ( + discoPath = "/.well-known/terraform.json" + maxRedirects = 3 // arbitrary-but-small number to prevent runaway redirect loops + discoTimeout = 4 * time.Second // arbitrary-but-small time limit to prevent UI "hangs" during discovery + maxDiscoDocBytes = 1 * 1024 * 1024 // 1MB - to prevent abusive services from using loads of our memory +) + +var userAgent = fmt.Sprintf("Terraform/%s (service discovery)", terraform.VersionString()) +var httpTransport = cleanhttp.DefaultPooledTransport() // overridden during tests, to skip TLS verification + +// Disco is the main type in this package, which allows discovery on given +// hostnames and caches the results by hostname to avoid repeated requests +// for the same information. +type Disco struct { + hostCache map[svchost.Hostname]Host +} + +func NewDisco() *Disco { + return &Disco{} +} + +// Discover runs the discovery protocol against the given hostname (which must +// already have been validated and prepared with svchost.ForComparison) and +// returns an object describing the services available at that host. +// +// If a given hostname supports no Terraform services at all, a non-nil but +// empty Host object is returned. When giving feedback to the end user about +// such situations, we say e.g. "the host doesn't provide a module +// registry", regardless of whether that is due to that service specifically +// being absent or due to the host not providing Terraform services at all, +// since we don't wish to expose the detail of whole-host discovery to an +// end-user. +func (d *Disco) Discover(host svchost.Hostname) Host { + if d.hostCache == nil { + d.hostCache = map[svchost.Hostname]Host{} + } + if cache, cached := d.hostCache[host]; cached { + return cache + } + + ret := d.discover(host) + d.hostCache[host] = ret + return ret +} + +// DiscoverServiceURL is a convenience wrapper for discovery on a given +// hostname and then looking up a particular service in the result. +func (d *Disco) DiscoverServiceURL(host svchost.Hostname, serviceID string) *url.URL { + return d.Discover(host).ServiceURL(serviceID) +} + +// discover implements the actual discovery process, with its result cached +// by the public-facing Discover method. +func (d *Disco) discover(host svchost.Hostname) Host { + discoURL := &url.URL{ + Scheme: "https", + Host: string(host), + Path: discoPath, + } + client := &http.Client{ + Transport: httpTransport, + Timeout: discoTimeout, + + CheckRedirect: func(req *http.Request, via []*http.Request) error { + log.Printf("[DEBUG] Service discovery redirected to %s", req.URL) + if len(via) > maxRedirects { + return errors.New("too many redirects") // (this error message will never actually be seen) + } + return nil + }, + } + + var header = http.Header{} + header.Set("User-Agent", userAgent) + // TODO: look up credentials and add them to the header if we have them + + req := &http.Request{ + Method: "GET", + URL: discoURL, + Header: header, + } + + log.Printf("[DEBUG] Service discovery for %s at %s", host, discoURL) + + ret := Host{ + discoURL: discoURL, + } + + resp, err := client.Do(req) + if err != nil { + log.Printf("[WARNING] Failed to request discovery document: %s", err) + return ret // empty + } + if resp.StatusCode != 200 { + log.Printf("[WARNING] Failed to request discovery document: %s", resp.Status) + return ret // empty + } + + // If the client followed any redirects, we will have a new URL to use + // as our base for relative resolution. + ret.discoURL = resp.Request.URL + + contentType := resp.Header.Get("Content-Type") + mediaType, _, err := mime.ParseMediaType(contentType) + if err != nil { + log.Printf("[WARNING] Discovery URL has malformed Content-Type %q", contentType) + return ret // empty + } + if mediaType != "application/json" { + log.Printf("[DEBUG] Discovery URL returned Content-Type %q, rather than application/json", mediaType) + return ret // empty + } + + // (this doesn't catch chunked encoding, because ContentLength is -1 in that case...) + if resp.ContentLength > maxDiscoDocBytes { + // Size limit here is not a contractual requirement and so we may + // adjust it over time if we find a different limit is warranted. + log.Printf("[WARNING] Discovery doc response is too large (got %d bytes; limit %d)", resp.ContentLength, maxDiscoDocBytes) + return ret // empty + } + + // If the response is using chunked encoding then we can't predict + // its size, but we'll at least prevent reading the entire thing into + // memory. + lr := io.LimitReader(resp.Body, maxDiscoDocBytes) + + servicesBytes, err := ioutil.ReadAll(lr) + if err != nil { + log.Printf("[WARNING] Error reading discovery document body: %s", err) + return ret // empty + } + + var services map[string]interface{} + err = json.Unmarshal(servicesBytes, &services) + if err != nil { + log.Printf("[WARNING] Failed to decode discovery document as a JSON object: %s", err) + return ret // empty + } + + ret.services = services + return ret +} + +// Forget invalidates any cached record of the given hostname. If the host +// has no cache entry then this is a no-op. +func (d *Disco) Forget(host svchost.Hostname) { + delete(d.hostCache, host) +} + +// ForgetAll is like Forget, but for all of the hostnames that have cache entries. +func (d *Disco) ForgetAll() { + d.hostCache = nil +} diff --git a/svchost/disco/disco_test.go b/svchost/disco/disco_test.go new file mode 100644 index 0000000000..d514d1f3fa --- /dev/null +++ b/svchost/disco/disco_test.go @@ -0,0 +1,255 @@ +package disco + +import ( + "crypto/tls" + "net/http" + "net/http/httptest" + "net/url" + "os" + "strconv" + "testing" + + "github.com/hashicorp/terraform/svchost" +) + +func TestMain(m *testing.M) { + // During all tests we override the HTTP transport we use for discovery + // so it'll tolerate the locally-generated TLS certificates we use + // for test URLs. + httpTransport = &http.Transport{ + TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, + } + + os.Exit(m.Run()) +} + +func TestDiscover(t *testing.T) { + t.Run("happy path", func(t *testing.T) { + portStr, close := testServer(func(w http.ResponseWriter, r *http.Request) { + resp := []byte(` +{ +"thingy.v1": "http://example.com/foo", +"wotsit.v2": "http://example.net/bar" +} +`) + w.Header().Add("Content-Type", "application/json") + w.Header().Add("Content-Length", strconv.Itoa(len(resp))) + w.Write(resp) + }) + defer close() + + givenHost := "localhost" + portStr + host, err := svchost.ForComparison(givenHost) + if err != nil { + t.Fatalf("test server hostname is invalid: %s", err) + } + + d := NewDisco() + discovered := d.Discover(host) + gotURL := discovered.ServiceURL("thingy.v1") + if gotURL == nil { + t.Fatalf("found no URL for thingy.v1") + } + if got, want := gotURL.String(), "http://example.com/foo"; got != want { + t.Fatalf("wrong result %q; want %q", got, want) + } + }) + t.Run("chunked encoding", func(t *testing.T) { + portStr, close := testServer(func(w http.ResponseWriter, r *http.Request) { + resp := []byte(` +{ +"thingy.v1": "http://example.com/foo", +"wotsit.v2": "http://example.net/bar" +} +`) + w.Header().Add("Content-Type", "application/json") + // We're going to force chunked encoding here -- and thus prevent + // the server from predicting the length -- so we can make sure + // our client is tolerant of servers using this encoding. + w.Write(resp[:5]) + w.(http.Flusher).Flush() + w.Write(resp[5:]) + w.(http.Flusher).Flush() + }) + defer close() + + givenHost := "localhost" + portStr + host, err := svchost.ForComparison(givenHost) + if err != nil { + t.Fatalf("test server hostname is invalid: %s", err) + } + + d := NewDisco() + discovered := d.Discover(host) + gotURL := discovered.ServiceURL("wotsit.v2") + if gotURL == nil { + t.Fatalf("found no URL for wotsit.v2") + } + if got, want := gotURL.String(), "http://example.net/bar"; got != want { + t.Fatalf("wrong result %q; want %q", got, want) + } + }) + t.Run("not JSON", func(t *testing.T) { + portStr, close := testServer(func(w http.ResponseWriter, r *http.Request) { + resp := []byte(`{"thingy.v1": "http://example.com/foo"}`) + w.Header().Add("Content-Type", "application/octet-stream") + w.Write(resp) + }) + defer close() + + givenHost := "localhost" + portStr + host, err := svchost.ForComparison(givenHost) + if err != nil { + t.Fatalf("test server hostname is invalid: %s", err) + } + + d := NewDisco() + discovered := d.Discover(host) + + // result should be empty, which we can verify only by reaching into + // its internals. + if discovered.services != nil { + t.Errorf("response not empty; should be") + } + }) + t.Run("malformed JSON", func(t *testing.T) { + portStr, close := testServer(func(w http.ResponseWriter, r *http.Request) { + resp := []byte(`{"thingy.v1": "htt`) // truncated, for example... + w.Header().Add("Content-Type", "application/json") + w.Write(resp) + }) + defer close() + + givenHost := "localhost" + portStr + host, err := svchost.ForComparison(givenHost) + if err != nil { + t.Fatalf("test server hostname is invalid: %s", err) + } + + d := NewDisco() + discovered := d.Discover(host) + + // result should be empty, which we can verify only by reaching into + // its internals. + if discovered.services != nil { + t.Errorf("response not empty; should be") + } + }) + t.Run("JSON with redundant charset", func(t *testing.T) { + // The JSON RFC defines no parameters for the application/json + // MIME type, but some servers have a weird tendency to just add + // "charset" to everything, so we'll make sure we ignore it successfully. + // (JSON uses content sniffing for encoding detection, not media type params.) + portStr, close := testServer(func(w http.ResponseWriter, r *http.Request) { + resp := []byte(`{"thingy.v1": "http://example.com/foo"}`) + w.Header().Add("Content-Type", "application/json; charset=latin-1") + w.Write(resp) + }) + defer close() + + givenHost := "localhost" + portStr + host, err := svchost.ForComparison(givenHost) + if err != nil { + t.Fatalf("test server hostname is invalid: %s", err) + } + + d := NewDisco() + discovered := d.Discover(host) + + if discovered.services == nil { + t.Errorf("response is empty; shouldn't be") + } + }) + t.Run("no discovery doc", func(t *testing.T) { + portStr, close := testServer(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(404) + }) + defer close() + + givenHost := "localhost" + portStr + host, err := svchost.ForComparison(givenHost) + if err != nil { + t.Fatalf("test server hostname is invalid: %s", err) + } + + d := NewDisco() + discovered := d.Discover(host) + + // result should be empty, which we can verify only by reaching into + // its internals. + if discovered.services != nil { + t.Errorf("response not empty; should be") + } + }) + t.Run("redirect", func(t *testing.T) { + // For this test, we have two servers and one redirects to the other + portStr1, close1 := testServer(func(w http.ResponseWriter, r *http.Request) { + // This server is the one that returns a real response. + resp := []byte(`{"thingy.v1": "http://example.com/foo"}`) + w.Header().Add("Content-Type", "application/json") + w.Header().Add("Content-Length", strconv.Itoa(len(resp))) + w.Write(resp) + }) + portStr2, close2 := testServer(func(w http.ResponseWriter, r *http.Request) { + // This server is the one that redirects. + http.Redirect(w, r, "https://127.0.0.1"+portStr1+"/.well-known/terraform.json", 302) + }) + defer close1() + defer close2() + + givenHost := "localhost" + portStr2 + host, err := svchost.ForComparison(givenHost) + if err != nil { + t.Fatalf("test server hostname is invalid: %s", err) + } + + d := NewDisco() + discovered := d.Discover(host) + + gotURL := discovered.ServiceURL("thingy.v1") + if gotURL == nil { + t.Fatalf("found no URL for thingy.v1") + } + if got, want := gotURL.String(), "http://example.com/foo"; got != want { + t.Fatalf("wrong result %q; want %q", got, want) + } + + // The base URL for the host object should be the URL we redirected to, + // rather than the we redirected _from_. + gotBaseURL := discovered.discoURL.String() + wantBaseURL := "https://127.0.0.1" + portStr1 + "/.well-known/terraform.json" + if gotBaseURL != wantBaseURL { + t.Errorf("incorrect base url %s; want %s", gotBaseURL, wantBaseURL) + } + + }) +} + +func testServer(h func(w http.ResponseWriter, r *http.Request)) (portStr string, close func()) { + server := httptest.NewTLSServer(http.HandlerFunc( + func(w http.ResponseWriter, r *http.Request) { + // Test server always returns 404 if the URL isn't what we expect + if r.URL.Path != "/.well-known/terraform.json" { + w.WriteHeader(404) + w.Write([]byte("not found")) + return + } + + // If the URL is correct then the given hander decides the response + h(w, r) + }, + )) + + serverURL, _ := url.Parse(server.URL) + + portStr = serverURL.Port() + if portStr != "" { + portStr = ":" + portStr + } + + close = func() { + server.Close() + } + + return +} diff --git a/svchost/disco/host.go b/svchost/disco/host.go new file mode 100644 index 0000000000..faf58220af --- /dev/null +++ b/svchost/disco/host.go @@ -0,0 +1,51 @@ +package disco + +import ( + "net/url" +) + +type Host struct { + discoURL *url.URL + services map[string]interface{} +} + +// ServiceURL returns the URL associated with the given service identifier, +// which should be of the form "servicename.vN". +// +// A non-nil result is always an absolute URL with a scheme of either https +// or http. +// +// If the requested service is not supported by the host, this method returns +// a nil URL. +// +// If the discovery document entry for the given service is invalid (not a URL), +// it is treated as absent, also returning a nil URL. +func (h Host) ServiceURL(id string) *url.URL { + if h.services == nil { + return nil // no services supported for an empty Host + } + + urlStr, ok := h.services[id].(string) + if !ok { + return nil + } + + ret, err := url.Parse(urlStr) + if err != nil { + return nil + } + if !ret.IsAbs() { + ret = h.discoURL.ResolveReference(ret) // make absolute using our discovery doc URL + } + if ret.Scheme != "https" && ret.Scheme != "http" { + return nil + } + if ret.User != nil { + // embedded username/password information is not permitted; credentials + // are handled out of band. + return nil + } + ret.Fragment = "" // fragment part is irrelevant, since we're not a browser + + return h.discoURL.ResolveReference(ret) +} diff --git a/svchost/disco/host_test.go b/svchost/disco/host_test.go new file mode 100644 index 0000000000..8a9fe4c761 --- /dev/null +++ b/svchost/disco/host_test.go @@ -0,0 +1,55 @@ +package disco + +import ( + "net/url" + "testing" +) + +func TestHostServiceURL(t *testing.T) { + baseURL, _ := url.Parse("https://example.com/disco/foo.json") + host := Host{ + discoURL: baseURL, + services: map[string]interface{}{ + "absolute.v1": "http://example.net/foo/bar", + "absolutewithport.v1": "http://example.net:8080/foo/bar", + "relative.v1": "./stu/", + "rootrelative.v1": "/baz", + "protorelative.v1": "//example.net/", + "withfragment.v1": "http://example.org/#foo", + "querystring.v1": "https://example.net/baz?foo=bar", + "nothttp.v1": "ftp://127.0.0.1/pub/", + "invalid.v1": "***not A URL at all!:/<@@@@>***", + }, + } + + tests := []struct { + ID string + Want string + }{ + {"absolute.v1", "http://example.net/foo/bar"}, + {"absolutewithport.v1", "http://example.net:8080/foo/bar"}, + {"relative.v1", "https://example.com/disco/stu/"}, + {"rootrelative.v1", "https://example.com/baz"}, + {"protorelative.v1", "https://example.net/"}, + {"withfragment.v1", "http://example.org/"}, + {"querystring.v1", "https://example.net/baz?foo=bar"}, // most callers will disregard query string + {"nothttp.v1", ""}, + {"invalid.v1", ""}, + } + + for _, test := range tests { + t.Run(test.ID, func(t *testing.T) { + url := host.ServiceURL(test.ID) + var got string + if url != nil { + got = url.String() + } else { + got = "" + } + + if got != test.Want { + t.Errorf("wrong result\ngot: %s\nwant: %s", got, test.Want) + } + }) + } +}