From d03fd37ee62030b975b83b85a15b847c24a5da4e Mon Sep 17 00:00:00 2001 From: Sebastian Rivera Date: Wed, 24 May 2023 20:25:28 -0400 Subject: [PATCH 1/2] Add interval header check to enable snapshots --- internal/cloud/backend.go | 4 +- internal/cloud/state.go | 58 ++++++++++----- internal/cloud/state_test.go | 133 ++++++++++++++--------------------- internal/cloud/testing.go | 65 +++++++++++++++++ 4 files changed, 157 insertions(+), 103 deletions(-) diff --git a/internal/cloud/backend.go b/internal/cloud/backend.go index f0142b0bd2..fd49e56788 100644 --- a/internal/cloud/backend.go +++ b/internal/cloud/backend.go @@ -575,7 +575,7 @@ func (b *Cloud) DeleteWorkspace(name string, force bool) error { } // Configure the remote workspace name. - State := &State{tfeClient: b.client, organization: b.organization, workspace: workspace} + State := &State{tfeClient: b.client, organization: b.organization, workspace: workspace, disableIntermediateSnapshots: true} return State.Delete(force) } @@ -661,7 +661,7 @@ func (b *Cloud) StateMgr(name string) (statemgr.Full, error) { } } - return &State{tfeClient: b.client, organization: b.organization, workspace: workspace}, nil + return &State{tfeClient: b.client, organization: b.organization, workspace: workspace, disableIntermediateSnapshots: true}, nil } // Operation implements backend.Enhanced. diff --git a/internal/cloud/state.go b/internal/cloud/state.go index f3abc0b929..6555a4e01e 100644 --- a/internal/cloud/state.go +++ b/internal/cloud/state.go @@ -69,6 +69,9 @@ type State struct { // not effect final snapshots after an operation, which will always // be written to the remote API. stateSnapshotInterval time.Duration + // If the header, X-Terraform-Snapshot-Interval is not present then + // we will disable snapshots + disableIntermediateSnapshots bool } var ErrStateVersionUnauthorizedUpgradeState = errors.New(strings.TrimSpace(` @@ -244,6 +247,10 @@ func (s *State) ShouldPersistIntermediateState(info *local.IntermediateStatePers return true } + if s.disableIntermediateSnapshots && info.RequestedPersistInterval == time.Duration(0) { + return false + } + // Our persist interval is the largest of either the caller's requested // interval or the server's requested interval. wantInterval := info.RequestedPersistInterval @@ -278,25 +285,7 @@ func (s *State) uploadState(lineage string, serial uint64, isForcePush bool, sta // The server is allowed to dynamically request a different time interval // than we'd normally use, for example if it's currently under heavy load // and needs clients to backoff for a while. - ctx = tfe.ContextWithResponseHeaderHook(ctx, func(status int, header http.Header) { - intervalStr := header.Get("x-terraform-snapshot-interval") - - if intervalSecs, err := strconv.ParseInt(intervalStr, 10, 64); err == nil { - if intervalSecs > 3600 { - // More than an hour is an unreasonable delay, so we'll just - // saturate at one hour. - intervalSecs = 3600 - } else if intervalSecs < 0 { - intervalSecs = 0 - } - s.stateSnapshotInterval = time.Duration(intervalSecs) * time.Second - } else { - // If the header field is either absent or invalid then we'll - // just choose zero, which effectively means that we'll just use - // the caller's requested interval instead. - s.stateSnapshotInterval = time.Duration(0) - } - }) + ctx = tfe.ContextWithResponseHeaderHook(ctx, s.readSnapshotIntervalHeader) // Create the new state. _, err := s.tfeClient.StateVersions.Create(ctx, s.workspace.ID, options) @@ -377,6 +366,10 @@ func (s *State) refreshState() error { func (s *State) getStatePayload() (*remote.Payload, error) { ctx := context.Background() + // Check the x-terraform-snapshot-interval header to see if it has a non-empty + // value which would indicate snapshots are enabled + ctx = tfe.ContextWithResponseHeaderHook(ctx, s.readSnapshotIntervalHeader) + sv, err := s.tfeClient.StateVersions.ReadCurrent(ctx, s.workspace.ID) if err != nil { if err == tfe.ErrResourceNotFound { @@ -542,6 +535,33 @@ func (s *State) GetRootOutputValues() (map[string]*states.OutputValue, error) { return result, nil } +func (s *State) readSnapshotIntervalHeader(status int, header http.Header) { + intervalStr := header.Get("x-terraform-snapshot-interval") + + if intervalSecs, err := strconv.ParseInt(intervalStr, 10, 64); err == nil { + if intervalSecs > 3600 { + // More than an hour is an unreasonable delay, so we'll just + // saturate at one hour. + intervalSecs = 3600 + } else if intervalSecs < 0 { + intervalSecs = 0 + } + s.stateSnapshotInterval = time.Duration(intervalSecs) * time.Second + + // We will only enable snapshots for intervals greater than zero + if intervalSecs > 0 { + s.disableIntermediateSnapshots = false + } + } else { + // If the header field is either absent or invalid then we'll + // just choose zero, which effectively means that we'll just use + // the caller's requested interval instead. If the caller has no + // requested interval or it is zero, then we will disable snapshots. + s.stateSnapshotInterval = time.Duration(0) + s.disableIntermediateSnapshots = true + } +} + // tfeOutputToCtyValue decodes a combination of TFE output value and detailed-type to create a // cty value that is suitable for use in terraform. func tfeOutputToCtyValue(output tfe.StateVersionOutput) (cty.Value, error) { diff --git a/internal/cloud/state_test.go b/internal/cloud/state_test.go index d29ce5bb8b..c3d31c9259 100644 --- a/internal/cloud/state_test.go +++ b/internal/cloud/state_test.go @@ -6,11 +6,7 @@ package cloud import ( "bytes" "context" - "encoding/json" "io/ioutil" - "net/http" - "net/http/httptest" - "strconv" "testing" "time" @@ -294,98 +290,71 @@ func TestState_PersistState(t *testing.T) { t.Error("state manager already has a nonzero snapshot interval") } + if !cloudState.disableIntermediateSnapshots { + t.Error("expected state manager to have disabled snapshots") + } + // For this test we'll use a real client talking to a fake server, // since HTTP-level concerns like headers are out of scope for the // mock client we typically use in other tests in this package, which // aim to abstract away HTTP altogether. var serverURL string - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - t.Log(r.Method, r.URL.String()) - if r.URL.Path == "/state-json" { - t.Log("pretending to be Archivist") - fakeState := states.NewState() - fakeStateFile := statefile.New(fakeState, "boop", 1) - var buf bytes.Buffer - statefile.Write(fakeStateFile, &buf) - respBody := buf.Bytes() - w.Header().Set("content-type", "application/json") - w.Header().Set("content-length", strconv.FormatInt(int64(len(respBody)), 10)) - w.WriteHeader(http.StatusOK) - w.Write(respBody) - return - } - if r.URL.Path == "/api/ping" { - t.Log("pretending to be Ping") - w.WriteHeader(http.StatusNoContent) - return - } + // Didn't want to repeat myself here + for _, testCase := range []struct { + expectedInterval time.Duration + snapshotsEnabled bool + }{ + { + expectedInterval: 300 * time.Second, + snapshotsEnabled: true, + }, + { + expectedInterval: 0 * time.Second, + snapshotsEnabled: false, + }, + } { + server := testServerWithSnapshotsEnabled(t, serverURL, testCase.snapshotsEnabled) - fakeBody := map[string]any{ - "data": map[string]any{ - "type": "state-versions", - "attributes": map[string]any{ - "hosted-state-download-url": serverURL + "/state-json", - }, - }, + defer server.Close() + serverURL = server.URL + cfg := &tfe.Config{ + Address: server.URL, + BasePath: "api", + Token: "placeholder", } - fakeBodyRaw, err := json.Marshal(fakeBody) + client, err := tfe.NewClient(cfg) + if err != nil { + t.Fatal(err) + } + cloudState.tfeClient = client + + err = cloudState.RefreshState() + if err != nil { + t.Fatal(err) + } + cloudState.WriteState(states.BuildState(func(s *states.SyncState) { + s.SetOutputValue( + addrs.OutputValue{Name: "boop"}.Absolute(addrs.RootModuleInstance), + cty.StringVal("beep"), false, + ) + })) + + err = cloudState.PersistState(nil) if err != nil { t.Fatal(err) } - w.Header().Set("content-type", "application/json") - w.Header().Set("content-length", strconv.FormatInt(int64(len(fakeBodyRaw)), 10)) - - switch r.Method { - case "POST": - t.Log("pretending to be Create a State Version") - w.Header().Set("x-terraform-snapshot-interval", "300") - w.WriteHeader(http.StatusAccepted) - case "GET": - t.Log("pretending to be Fetch the Current State Version for a Workspace") - w.WriteHeader(http.StatusOK) - default: - t.Fatal("don't know what API operation this was supposed to be") + // The PersistState call above should have sent a request to the test + // server and got back the x-terraform-snapshot-interval header, whose + // value should therefore now be recorded in the relevant field. + if got := cloudState.stateSnapshotInterval; got != testCase.expectedInterval { + t.Errorf("wrong state snapshot interval after PersistState\ngot: %s\nwant: %s", got, testCase.expectedInterval) } - w.WriteHeader(http.StatusOK) - w.Write(fakeBodyRaw) - })) - defer server.Close() - serverURL = server.URL - cfg := &tfe.Config{ - Address: server.URL, - BasePath: "api", - Token: "placeholder", - } - client, err := tfe.NewClient(cfg) - if err != nil { - t.Fatal(err) - } - cloudState.tfeClient = client - - err = cloudState.RefreshState() - if err != nil { - t.Fatal(err) - } - cloudState.WriteState(states.BuildState(func(s *states.SyncState) { - s.SetOutputValue( - addrs.OutputValue{Name: "boop"}.Absolute(addrs.RootModuleInstance), - cty.StringVal("beep"), false, - ) - })) - - err = cloudState.PersistState(nil) - if err != nil { - t.Fatal(err) - } - - // The PersistState call above should have sent a request to the test - // server and got back the x-terraform-snapshot-interval header, whose - // value should therefore now be recorded in the relevant field. - if got, want := cloudState.stateSnapshotInterval, 300*time.Second; got != want { - t.Errorf("wrong state snapshot interval after PersistState\ngot: %s\nwant: %s", got, want) + if got, want := cloudState.disableIntermediateSnapshots, !testCase.snapshotsEnabled; got != want { + t.Errorf("expected disable intermediate snapshots to be\ngot: %t\nwant: %t", got, want) + } } }) } diff --git a/internal/cloud/testing.go b/internal/cloud/testing.go index 0168f8b851..8f81f21fc1 100644 --- a/internal/cloud/testing.go +++ b/internal/cloud/testing.go @@ -4,6 +4,7 @@ package cloud import ( + "bytes" "context" "encoding/json" "fmt" @@ -12,6 +13,7 @@ import ( "net/http/httptest" "net/url" "path" + "strconv" "testing" "time" @@ -29,6 +31,8 @@ import ( "github.com/hashicorp/terraform/internal/configs/configschema" "github.com/hashicorp/terraform/internal/httpclient" "github.com/hashicorp/terraform/internal/providers" + "github.com/hashicorp/terraform/internal/states" + "github.com/hashicorp/terraform/internal/states/statefile" "github.com/hashicorp/terraform/internal/terraform" "github.com/hashicorp/terraform/internal/tfdiags" "github.com/hashicorp/terraform/version" @@ -379,6 +383,67 @@ func testServerWithHandlers(handlers map[string]func(http.ResponseWriter, *http. return httptest.NewServer(mux) } +func testServerWithSnapshotsEnabled(t *testing.T, serverURL string, enabled bool) *httptest.Server { + return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + t.Log(r.Method, r.URL.String()) + + if r.URL.Path == "/state-json" { + t.Log("pretending to be Archivist") + fakeState := states.NewState() + fakeStateFile := statefile.New(fakeState, "boop", 1) + var buf bytes.Buffer + statefile.Write(fakeStateFile, &buf) + respBody := buf.Bytes() + w.Header().Set("content-type", "application/json") + w.Header().Set("content-length", strconv.FormatInt(int64(len(respBody)), 10)) + w.WriteHeader(http.StatusOK) + w.Write(respBody) + return + } + if r.URL.Path == "/api/ping" { + t.Log("pretending to be Ping") + w.WriteHeader(http.StatusNoContent) + return + } + + fakeBody := map[string]any{ + "data": map[string]any{ + "type": "state-versions", + "attributes": map[string]any{ + "hosted-state-download-url": serverURL + "/state-json", + }, + }, + } + fakeBodyRaw, err := json.Marshal(fakeBody) + if err != nil { + t.Fatal(err) + } + + w.Header().Set("content-type", "application/json") + w.Header().Set("content-length", strconv.FormatInt(int64(len(fakeBodyRaw)), 10)) + + switch r.Method { + case "POST": + t.Log("pretending to be Create a State Version") + if enabled { + w.Header().Set("x-terraform-snapshot-interval", "300") + } + w.WriteHeader(http.StatusAccepted) + case "GET": + t.Log("pretending to be Fetch the Current State Version for a Workspace") + if enabled { + w.Header().Set("x-terraform-snapshot-interval", "300") + } + w.WriteHeader(http.StatusOK) + default: + t.Fatal("don't know what API operation this was supposed to be") + } + + w.WriteHeader(http.StatusOK) + w.Write(fakeBodyRaw) + })) +} + // testDefaultRequestHandlers is a map of request handlers intended to be used in a request // multiplexer for a test server. A caller may use testServerWithHandlers to start a server with // this base set of routes, and override a particular route for whatever edge case is being tested. From 86eed095b34d0ee5e8d9afea03fb44cc5c304950 Mon Sep 17 00:00:00 2001 From: Brandon Croft Date: Tue, 30 May 2023 12:35:23 -0600 Subject: [PATCH 2/2] Rename disableIntermediateSnapshots > enableIntermediateSnapshots --- internal/cloud/backend.go | 4 ++-- internal/cloud/state.go | 36 +++++++++++++++++++----------------- internal/cloud/state_test.go | 4 ++-- 3 files changed, 23 insertions(+), 21 deletions(-) diff --git a/internal/cloud/backend.go b/internal/cloud/backend.go index fd49e56788..b851fafb03 100644 --- a/internal/cloud/backend.go +++ b/internal/cloud/backend.go @@ -575,7 +575,7 @@ func (b *Cloud) DeleteWorkspace(name string, force bool) error { } // Configure the remote workspace name. - State := &State{tfeClient: b.client, organization: b.organization, workspace: workspace, disableIntermediateSnapshots: true} + State := &State{tfeClient: b.client, organization: b.organization, workspace: workspace, enableIntermediateSnapshots: false} return State.Delete(force) } @@ -661,7 +661,7 @@ func (b *Cloud) StateMgr(name string) (statemgr.Full, error) { } } - return &State{tfeClient: b.client, organization: b.organization, workspace: workspace, disableIntermediateSnapshots: true}, nil + return &State{tfeClient: b.client, organization: b.organization, workspace: workspace, enableIntermediateSnapshots: false}, nil } // Operation implements backend.Enhanced. diff --git a/internal/cloud/state.go b/internal/cloud/state.go index 6555a4e01e..69bafd01b3 100644 --- a/internal/cloud/state.go +++ b/internal/cloud/state.go @@ -69,9 +69,9 @@ type State struct { // not effect final snapshots after an operation, which will always // be written to the remote API. stateSnapshotInterval time.Duration - // If the header, X-Terraform-Snapshot-Interval is not present then - // we will disable snapshots - disableIntermediateSnapshots bool + // If the header X-Terraform-Snapshot-Interval is present then + // we will enable snapshots + enableIntermediateSnapshots bool } var ErrStateVersionUnauthorizedUpgradeState = errors.New(strings.TrimSpace(` @@ -247,7 +247,7 @@ func (s *State) ShouldPersistIntermediateState(info *local.IntermediateStatePers return true } - if s.disableIntermediateSnapshots && info.RequestedPersistInterval == time.Duration(0) { + if !s.enableIntermediateSnapshots && info.RequestedPersistInterval == time.Duration(0) { return false } @@ -535,31 +535,33 @@ func (s *State) GetRootOutputValues() (map[string]*states.OutputValue, error) { return result, nil } +func clamp(val, min, max int64) int64 { + if val < min { + return min + } else if val > max { + return max + } + return val +} + func (s *State) readSnapshotIntervalHeader(status int, header http.Header) { intervalStr := header.Get("x-terraform-snapshot-interval") if intervalSecs, err := strconv.ParseInt(intervalStr, 10, 64); err == nil { - if intervalSecs > 3600 { - // More than an hour is an unreasonable delay, so we'll just - // saturate at one hour. - intervalSecs = 3600 - } else if intervalSecs < 0 { - intervalSecs = 0 - } + // More than an hour is an unreasonable delay, so we'll just + // limit to one hour max. + intervalSecs = clamp(intervalSecs, 0, 3600) s.stateSnapshotInterval = time.Duration(intervalSecs) * time.Second - - // We will only enable snapshots for intervals greater than zero - if intervalSecs > 0 { - s.disableIntermediateSnapshots = false - } } else { // If the header field is either absent or invalid then we'll // just choose zero, which effectively means that we'll just use // the caller's requested interval instead. If the caller has no // requested interval or it is zero, then we will disable snapshots. s.stateSnapshotInterval = time.Duration(0) - s.disableIntermediateSnapshots = true } + + // We will only enable snapshots for intervals greater than zero + s.enableIntermediateSnapshots = s.stateSnapshotInterval > 0 } // tfeOutputToCtyValue decodes a combination of TFE output value and detailed-type to create a diff --git a/internal/cloud/state_test.go b/internal/cloud/state_test.go index c3d31c9259..14930a274d 100644 --- a/internal/cloud/state_test.go +++ b/internal/cloud/state_test.go @@ -290,7 +290,7 @@ func TestState_PersistState(t *testing.T) { t.Error("state manager already has a nonzero snapshot interval") } - if !cloudState.disableIntermediateSnapshots { + if cloudState.enableIntermediateSnapshots { t.Error("expected state manager to have disabled snapshots") } @@ -352,7 +352,7 @@ func TestState_PersistState(t *testing.T) { t.Errorf("wrong state snapshot interval after PersistState\ngot: %s\nwant: %s", got, testCase.expectedInterval) } - if got, want := cloudState.disableIntermediateSnapshots, !testCase.snapshotsEnabled; got != want { + if got, want := cloudState.enableIntermediateSnapshots, testCase.snapshotsEnabled; got != want { t.Errorf("expected disable intermediate snapshots to be\ngot: %t\nwant: %t", got, want) } }