Merge pull request #33267 from hashicorp/sebasslash/snapshot-interval-header-check

[cloud] Add interval header check to enable snapshots
This commit is contained in:
Brandon Croft 2023-05-30 12:41:32 -06:00 committed by GitHub
commit ea0ebcdc05
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 159 additions and 103 deletions

View File

@ -575,7 +575,7 @@ func (b *Cloud) DeleteWorkspace(name string, force bool) error {
}
// Configure the remote workspace name.
State := &State{tfeClient: b.client, organization: b.organization, workspace: workspace}
State := &State{tfeClient: b.client, organization: b.organization, workspace: workspace, enableIntermediateSnapshots: false}
return State.Delete(force)
}
@ -661,7 +661,7 @@ func (b *Cloud) StateMgr(name string) (statemgr.Full, error) {
}
}
return &State{tfeClient: b.client, organization: b.organization, workspace: workspace}, nil
return &State{tfeClient: b.client, organization: b.organization, workspace: workspace, enableIntermediateSnapshots: false}, nil
}
// Operation implements backend.Enhanced.

View File

@ -69,6 +69,9 @@ type State struct {
// not effect final snapshots after an operation, which will always
// be written to the remote API.
stateSnapshotInterval time.Duration
// If the header X-Terraform-Snapshot-Interval is present then
// we will enable snapshots
enableIntermediateSnapshots bool
}
var ErrStateVersionUnauthorizedUpgradeState = errors.New(strings.TrimSpace(`
@ -244,6 +247,10 @@ func (s *State) ShouldPersistIntermediateState(info *local.IntermediateStatePers
return true
}
if !s.enableIntermediateSnapshots && info.RequestedPersistInterval == time.Duration(0) {
return false
}
// Our persist interval is the largest of either the caller's requested
// interval or the server's requested interval.
wantInterval := info.RequestedPersistInterval
@ -278,25 +285,7 @@ func (s *State) uploadState(lineage string, serial uint64, isForcePush bool, sta
// The server is allowed to dynamically request a different time interval
// than we'd normally use, for example if it's currently under heavy load
// and needs clients to backoff for a while.
ctx = tfe.ContextWithResponseHeaderHook(ctx, func(status int, header http.Header) {
intervalStr := header.Get("x-terraform-snapshot-interval")
if intervalSecs, err := strconv.ParseInt(intervalStr, 10, 64); err == nil {
if intervalSecs > 3600 {
// More than an hour is an unreasonable delay, so we'll just
// saturate at one hour.
intervalSecs = 3600
} else if intervalSecs < 0 {
intervalSecs = 0
}
s.stateSnapshotInterval = time.Duration(intervalSecs) * time.Second
} else {
// If the header field is either absent or invalid then we'll
// just choose zero, which effectively means that we'll just use
// the caller's requested interval instead.
s.stateSnapshotInterval = time.Duration(0)
}
})
ctx = tfe.ContextWithResponseHeaderHook(ctx, s.readSnapshotIntervalHeader)
// Create the new state.
_, err := s.tfeClient.StateVersions.Create(ctx, s.workspace.ID, options)
@ -377,6 +366,10 @@ func (s *State) refreshState() error {
func (s *State) getStatePayload() (*remote.Payload, error) {
ctx := context.Background()
// Check the x-terraform-snapshot-interval header to see if it has a non-empty
// value which would indicate snapshots are enabled
ctx = tfe.ContextWithResponseHeaderHook(ctx, s.readSnapshotIntervalHeader)
sv, err := s.tfeClient.StateVersions.ReadCurrent(ctx, s.workspace.ID)
if err != nil {
if err == tfe.ErrResourceNotFound {
@ -542,6 +535,35 @@ func (s *State) GetRootOutputValues() (map[string]*states.OutputValue, error) {
return result, nil
}
func clamp(val, min, max int64) int64 {
if val < min {
return min
} else if val > max {
return max
}
return val
}
func (s *State) readSnapshotIntervalHeader(status int, header http.Header) {
intervalStr := header.Get("x-terraform-snapshot-interval")
if intervalSecs, err := strconv.ParseInt(intervalStr, 10, 64); err == nil {
// More than an hour is an unreasonable delay, so we'll just
// limit to one hour max.
intervalSecs = clamp(intervalSecs, 0, 3600)
s.stateSnapshotInterval = time.Duration(intervalSecs) * time.Second
} else {
// If the header field is either absent or invalid then we'll
// just choose zero, which effectively means that we'll just use
// the caller's requested interval instead. If the caller has no
// requested interval or it is zero, then we will disable snapshots.
s.stateSnapshotInterval = time.Duration(0)
}
// We will only enable snapshots for intervals greater than zero
s.enableIntermediateSnapshots = s.stateSnapshotInterval > 0
}
// tfeOutputToCtyValue decodes a combination of TFE output value and detailed-type to create a
// cty value that is suitable for use in terraform.
func tfeOutputToCtyValue(output tfe.StateVersionOutput) (cty.Value, error) {

View File

@ -6,11 +6,7 @@ package cloud
import (
"bytes"
"context"
"encoding/json"
"io/ioutil"
"net/http"
"net/http/httptest"
"strconv"
"testing"
"time"
@ -294,98 +290,71 @@ func TestState_PersistState(t *testing.T) {
t.Error("state manager already has a nonzero snapshot interval")
}
if cloudState.enableIntermediateSnapshots {
t.Error("expected state manager to have disabled snapshots")
}
// For this test we'll use a real client talking to a fake server,
// since HTTP-level concerns like headers are out of scope for the
// mock client we typically use in other tests in this package, which
// aim to abstract away HTTP altogether.
var serverURL string
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
t.Log(r.Method, r.URL.String())
if r.URL.Path == "/state-json" {
t.Log("pretending to be Archivist")
fakeState := states.NewState()
fakeStateFile := statefile.New(fakeState, "boop", 1)
var buf bytes.Buffer
statefile.Write(fakeStateFile, &buf)
respBody := buf.Bytes()
w.Header().Set("content-type", "application/json")
w.Header().Set("content-length", strconv.FormatInt(int64(len(respBody)), 10))
w.WriteHeader(http.StatusOK)
w.Write(respBody)
return
}
if r.URL.Path == "/api/ping" {
t.Log("pretending to be Ping")
w.WriteHeader(http.StatusNoContent)
return
}
// Didn't want to repeat myself here
for _, testCase := range []struct {
expectedInterval time.Duration
snapshotsEnabled bool
}{
{
expectedInterval: 300 * time.Second,
snapshotsEnabled: true,
},
{
expectedInterval: 0 * time.Second,
snapshotsEnabled: false,
},
} {
server := testServerWithSnapshotsEnabled(t, serverURL, testCase.snapshotsEnabled)
fakeBody := map[string]any{
"data": map[string]any{
"type": "state-versions",
"attributes": map[string]any{
"hosted-state-download-url": serverURL + "/state-json",
},
},
defer server.Close()
serverURL = server.URL
cfg := &tfe.Config{
Address: server.URL,
BasePath: "api",
Token: "placeholder",
}
fakeBodyRaw, err := json.Marshal(fakeBody)
client, err := tfe.NewClient(cfg)
if err != nil {
t.Fatal(err)
}
cloudState.tfeClient = client
err = cloudState.RefreshState()
if err != nil {
t.Fatal(err)
}
cloudState.WriteState(states.BuildState(func(s *states.SyncState) {
s.SetOutputValue(
addrs.OutputValue{Name: "boop"}.Absolute(addrs.RootModuleInstance),
cty.StringVal("beep"), false,
)
}))
err = cloudState.PersistState(nil)
if err != nil {
t.Fatal(err)
}
w.Header().Set("content-type", "application/json")
w.Header().Set("content-length", strconv.FormatInt(int64(len(fakeBodyRaw)), 10))
switch r.Method {
case "POST":
t.Log("pretending to be Create a State Version")
w.Header().Set("x-terraform-snapshot-interval", "300")
w.WriteHeader(http.StatusAccepted)
case "GET":
t.Log("pretending to be Fetch the Current State Version for a Workspace")
w.WriteHeader(http.StatusOK)
default:
t.Fatal("don't know what API operation this was supposed to be")
// The PersistState call above should have sent a request to the test
// server and got back the x-terraform-snapshot-interval header, whose
// value should therefore now be recorded in the relevant field.
if got := cloudState.stateSnapshotInterval; got != testCase.expectedInterval {
t.Errorf("wrong state snapshot interval after PersistState\ngot: %s\nwant: %s", got, testCase.expectedInterval)
}
w.WriteHeader(http.StatusOK)
w.Write(fakeBodyRaw)
}))
defer server.Close()
serverURL = server.URL
cfg := &tfe.Config{
Address: server.URL,
BasePath: "api",
Token: "placeholder",
}
client, err := tfe.NewClient(cfg)
if err != nil {
t.Fatal(err)
}
cloudState.tfeClient = client
err = cloudState.RefreshState()
if err != nil {
t.Fatal(err)
}
cloudState.WriteState(states.BuildState(func(s *states.SyncState) {
s.SetOutputValue(
addrs.OutputValue{Name: "boop"}.Absolute(addrs.RootModuleInstance),
cty.StringVal("beep"), false,
)
}))
err = cloudState.PersistState(nil)
if err != nil {
t.Fatal(err)
}
// The PersistState call above should have sent a request to the test
// server and got back the x-terraform-snapshot-interval header, whose
// value should therefore now be recorded in the relevant field.
if got, want := cloudState.stateSnapshotInterval, 300*time.Second; got != want {
t.Errorf("wrong state snapshot interval after PersistState\ngot: %s\nwant: %s", got, want)
if got, want := cloudState.enableIntermediateSnapshots, testCase.snapshotsEnabled; got != want {
t.Errorf("expected disable intermediate snapshots to be\ngot: %t\nwant: %t", got, want)
}
}
})
}

View File

@ -4,6 +4,7 @@
package cloud
import (
"bytes"
"context"
"encoding/json"
"fmt"
@ -12,6 +13,7 @@ import (
"net/http/httptest"
"net/url"
"path"
"strconv"
"testing"
"time"
@ -29,6 +31,8 @@ import (
"github.com/hashicorp/terraform/internal/configs/configschema"
"github.com/hashicorp/terraform/internal/httpclient"
"github.com/hashicorp/terraform/internal/providers"
"github.com/hashicorp/terraform/internal/states"
"github.com/hashicorp/terraform/internal/states/statefile"
"github.com/hashicorp/terraform/internal/terraform"
"github.com/hashicorp/terraform/internal/tfdiags"
"github.com/hashicorp/terraform/version"
@ -379,6 +383,67 @@ func testServerWithHandlers(handlers map[string]func(http.ResponseWriter, *http.
return httptest.NewServer(mux)
}
func testServerWithSnapshotsEnabled(t *testing.T, serverURL string, enabled bool) *httptest.Server {
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
t.Log(r.Method, r.URL.String())
if r.URL.Path == "/state-json" {
t.Log("pretending to be Archivist")
fakeState := states.NewState()
fakeStateFile := statefile.New(fakeState, "boop", 1)
var buf bytes.Buffer
statefile.Write(fakeStateFile, &buf)
respBody := buf.Bytes()
w.Header().Set("content-type", "application/json")
w.Header().Set("content-length", strconv.FormatInt(int64(len(respBody)), 10))
w.WriteHeader(http.StatusOK)
w.Write(respBody)
return
}
if r.URL.Path == "/api/ping" {
t.Log("pretending to be Ping")
w.WriteHeader(http.StatusNoContent)
return
}
fakeBody := map[string]any{
"data": map[string]any{
"type": "state-versions",
"attributes": map[string]any{
"hosted-state-download-url": serverURL + "/state-json",
},
},
}
fakeBodyRaw, err := json.Marshal(fakeBody)
if err != nil {
t.Fatal(err)
}
w.Header().Set("content-type", "application/json")
w.Header().Set("content-length", strconv.FormatInt(int64(len(fakeBodyRaw)), 10))
switch r.Method {
case "POST":
t.Log("pretending to be Create a State Version")
if enabled {
w.Header().Set("x-terraform-snapshot-interval", "300")
}
w.WriteHeader(http.StatusAccepted)
case "GET":
t.Log("pretending to be Fetch the Current State Version for a Workspace")
if enabled {
w.Header().Set("x-terraform-snapshot-interval", "300")
}
w.WriteHeader(http.StatusOK)
default:
t.Fatal("don't know what API operation this was supposed to be")
}
w.WriteHeader(http.StatusOK)
w.Write(fakeBodyRaw)
}))
}
// testDefaultRequestHandlers is a map of request handlers intended to be used in a request
// multiplexer for a test server. A caller may use testServerWithHandlers to start a server with
// this base set of routes, and override a particular route for whatever edge case is being tested.