From 275dd116f92a9c526968a1b748f5b70b45617d86 Mon Sep 17 00:00:00 2001 From: Marcin Wyszynski Date: Wed, 25 Oct 2023 12:37:58 +0200 Subject: [PATCH] Pass context to all remote.Client operations (#786) Signed-off-by: Marcin Wyszynski --- internal/backend/local/backend.go | 7 +- internal/backend/local/backend_local.go | 2 +- internal/backend/local/backend_local_test.go | 8 +-- internal/backend/local/hook_state.go | 5 +- internal/backend/local/hook_state_test.go | 5 +- .../remote-state/azure/backend_state.go | 6 +- internal/backend/remote-state/azure/client.go | 7 +- .../backend/remote-state/azure/client_test.go | 2 +- .../remote-state/consul/backend_state.go | 4 +- .../backend/remote-state/consul/client.go | 4 +- .../remote-state/consul/client_test.go | 4 +- .../backend/remote-state/cos/backend_state.go | 4 +- internal/backend/remote-state/cos/client.go | 20 +++--- .../backend/remote-state/gcs/backend_state.go | 4 +- internal/backend/remote-state/gcs/client.go | 8 +-- internal/backend/remote-state/http/client.go | 8 +-- .../backend/remote-state/http/server_test.go | 8 +-- .../backend/remote-state/inmem/backend.go | 2 +- .../remote-state/inmem/backend_test.go | 4 +- internal/backend/remote-state/inmem/client.go | 4 +- .../remote-state/kubernetes/backend_state.go | 4 +- .../backend/remote-state/kubernetes/client.go | 7 +- .../backend/remote-state/oss/backend_state.go | 4 +- internal/backend/remote-state/oss/client.go | 4 +- .../backend/remote-state/oss/client_test.go | 18 ++--- .../backend/remote-state/pg/backend_state.go | 2 +- .../backend/remote-state/pg/backend_test.go | 6 +- internal/backend/remote-state/pg/client.go | 8 +-- .../backend/remote-state/s3/backend_state.go | 4 +- .../backend/remote-state/s3/backend_test.go | 16 ++--- internal/backend/remote-state/s3/client.go | 6 +- .../backend/remote-state/s3/client_test.go | 18 ++--- internal/backend/remote/backend_context.go | 2 +- internal/backend/remote/backend_state.go | 8 +-- internal/backend/remote/backend_state_test.go | 2 +- internal/backend/testing.go | 20 +++--- .../builtin/providers/tf/data_source_state.go | 2 +- internal/cloud/backend_context.go | 2 +- internal/cloud/state.go | 28 +++----- internal/cloud/state_test.go | 23 +++--- internal/command/import.go | 6 +- internal/command/init.go | 2 +- internal/command/meta_backend.go | 4 +- internal/command/meta_backend_migrate.go | 16 ++--- internal/command/meta_backend_test.go | 72 +++++++++---------- internal/command/output.go | 2 +- internal/command/providers.go | 2 +- internal/command/refresh_test.go | 3 +- internal/command/show.go | 2 +- internal/command/state_list.go | 2 +- internal/command/state_mv.go | 11 +-- internal/command/state_pull.go | 2 +- internal/command/state_push.go | 4 +- internal/command/state_push_test.go | 2 +- internal/command/state_replace_provider.go | 7 +- internal/command/state_rm.go | 7 +- internal/command/state_show.go | 2 +- internal/command/taint.go | 4 +- internal/command/untaint.go | 4 +- internal/command/workspace_command_test.go | 2 +- internal/command/workspace_delete.go | 2 +- internal/command/workspace_new.go | 2 +- internal/states/remote/remote.go | 4 +- internal/states/remote/remote_test.go | 14 ++-- internal/states/remote/state.go | 19 ++--- internal/states/remote/state_test.go | 27 ++++--- internal/states/remote/testing.go | 10 +-- internal/states/statemgr/filesystem.go | 11 +-- internal/states/statemgr/filesystem_test.go | 9 +-- internal/states/statemgr/helper.go | 6 +- internal/states/statemgr/lock.go | 14 ++-- internal/states/statemgr/persistent.go | 8 ++- internal/states/statemgr/statemgr_fake.go | 13 ++-- internal/states/statemgr/testing.go | 13 ++-- 74 files changed, 321 insertions(+), 287 deletions(-) diff --git a/internal/backend/local/backend.go b/internal/backend/local/backend.go index 63744b0da4..b182e8da2f 100644 --- a/internal/backend/local/backend.go +++ b/internal/backend/local/backend.go @@ -348,9 +348,14 @@ func (b *Local) opWait( case <-stopCtx.Done(): view.Stopping() + // We want to have a context that's guaranteed to be active that can be + // used to persist the state. Otherwise, if the operation is canceled + // or stopped before we can persist the state, we'll lose the state. + persistCtx := context.Background() + // try to force a PersistState just in case the process is terminated // before we can complete. - if err := opStateMgr.PersistState(nil); err != nil { + if err := opStateMgr.PersistState(persistCtx, nil); err != nil { // We can't error out from here, but warn the user if there was an error. // If this isn't transient, we will catch it again below, and // attempt to save the state another way. diff --git a/internal/backend/local/backend_local.go b/internal/backend/local/backend_local.go index 27e6ee19fb..135ca44bc8 100644 --- a/internal/backend/local/backend_local.go +++ b/internal/backend/local/backend_local.go @@ -62,7 +62,7 @@ func (b *Local) localRun(op *backend.Operation) (*backend.LocalRun, *configload. }() log.Printf("[TRACE] backend/local: reading remote state for workspace %q", op.Workspace) - if err := s.RefreshState(); err != nil { + if err := s.RefreshState(ctx); err != nil { diags = diags.Append(fmt.Errorf("error loading state: %w", err)) return nil, nil, nil, diags } diff --git a/internal/backend/local/backend_local_test.go b/internal/backend/local/backend_local_test.go index 2a1aaba676..6dc49d6a80 100644 --- a/internal/backend/local/backend_local_test.go +++ b/internal/backend/local/backend_local_test.go @@ -145,7 +145,7 @@ func TestLocalRun_stalePlan(t *testing.T) { if err != nil { t.Fatalf("unexpected error: %s", err) } - if err := sm.RefreshState(); err != nil { + if err := sm.RefreshState(ctx); err != nil { t.Fatalf("unexpected error refreshing state: %s", err) } @@ -263,7 +263,7 @@ func (s *stateStorageThatFailsRefresh) State() *states.State { return nil } -func (s *stateStorageThatFailsRefresh) GetRootOutputValues() (map[string]*states.OutputValue, error) { +func (s *stateStorageThatFailsRefresh) GetRootOutputValues(context.Context) (map[string]*states.OutputValue, error) { return nil, fmt.Errorf("unimplemented") } @@ -271,10 +271,10 @@ func (s *stateStorageThatFailsRefresh) WriteState(*states.State) error { return fmt.Errorf("unimplemented") } -func (s *stateStorageThatFailsRefresh) RefreshState() error { +func (s *stateStorageThatFailsRefresh) RefreshState(context.Context) error { return fmt.Errorf("intentionally failing for testing purposes") } -func (s *stateStorageThatFailsRefresh) PersistState(schemas *tofu.Schemas) error { +func (s *stateStorageThatFailsRefresh) PersistState(_ context.Context, schemas *tofu.Schemas) error { return fmt.Errorf("unimplemented") } diff --git a/internal/backend/local/hook_state.go b/internal/backend/local/hook_state.go index 4062bd8671..58198b570b 100644 --- a/internal/backend/local/hook_state.go +++ b/internal/backend/local/hook_state.go @@ -4,6 +4,7 @@ package local import ( + "context" "log" "sync" "time" @@ -79,7 +80,7 @@ func (h *StateHook) PostStateUpdate(new *states.State) (tofu.HookAction, error) } if mgrPersist, ok := h.StateMgr.(statemgr.Persister); ok && h.PersistInterval != 0 && h.Schemas != nil { if h.shouldPersist() { - err := mgrPersist.PersistState(h.Schemas) + err := mgrPersist.PersistState(context.TODO(), h.Schemas) if err != nil { return tofu.HookActionHalt, err } @@ -113,7 +114,7 @@ func (h *StateHook) Stopping() { h.intermediatePersist.ForcePersist = true if h.shouldPersist() { - err := mgrPersist.PersistState(h.Schemas) + err := mgrPersist.PersistState(context.TODO(), h.Schemas) if err != nil { // This hook can't affect OpenTofu Core's ongoing behavior, // but it's a best effort thing anyway, so we'll just emit a diff --git a/internal/backend/local/hook_state_test.go b/internal/backend/local/hook_state_test.go index c12944833b..0898e8d3c4 100644 --- a/internal/backend/local/hook_state_test.go +++ b/internal/backend/local/hook_state_test.go @@ -4,6 +4,7 @@ package local import ( + "context" "fmt" "testing" "time" @@ -255,7 +256,7 @@ func (sm *testPersistentState) WriteState(state *states.State) error { return nil } -func (sm *testPersistentState) PersistState(schemas *tofu.Schemas) error { +func (sm *testPersistentState) PersistState(_ context.Context, schemas *tofu.Schemas) error { if schemas == nil { return fmt.Errorf("no schemas") } @@ -281,7 +282,7 @@ func (sm *testPersistentStateThatRefusesToPersist) WriteState(state *states.Stat return nil } -func (sm *testPersistentStateThatRefusesToPersist) PersistState(schemas *tofu.Schemas) error { +func (sm *testPersistentStateThatRefusesToPersist) PersistState(_ context.Context, schemas *tofu.Schemas) error { if schemas == nil { return fmt.Errorf("no schemas") } diff --git a/internal/backend/remote-state/azure/backend_state.go b/internal/backend/remote-state/azure/backend_state.go index 83282566eb..c115d6da2f 100644 --- a/internal/backend/remote-state/azure/backend_state.go +++ b/internal/backend/remote-state/azure/backend_state.go @@ -96,7 +96,7 @@ func (b *Backend) StateMgr(ctx context.Context, name string) (statemgr.Full, err stateMgr := &remote.State{Client: client} // Grab the value - if err := stateMgr.RefreshState(); err != nil { + if err := stateMgr.RefreshState(ctx); err != nil { return nil, err } //if this isn't the default state name, we need to create the object so @@ -119,7 +119,7 @@ func (b *Backend) StateMgr(ctx context.Context, name string) (statemgr.Full, err } // Grab the value - if err := stateMgr.RefreshState(); err != nil { + if err := stateMgr.RefreshState(ctx); err != nil { err = lockUnlock(err) return nil, err } @@ -131,7 +131,7 @@ func (b *Backend) StateMgr(ctx context.Context, name string) (statemgr.Full, err err = lockUnlock(err) return nil, err } - if err := stateMgr.PersistState(nil); err != nil { + if err := stateMgr.PersistState(ctx, nil); err != nil { err = lockUnlock(err) return nil, err } diff --git a/internal/backend/remote-state/azure/client.go b/internal/backend/remote-state/azure/client.go index 8fff2471a9..8e586579c6 100644 --- a/internal/backend/remote-state/azure/client.go +++ b/internal/backend/remote-state/azure/client.go @@ -33,13 +33,12 @@ type RemoteClient struct { snapshot bool } -func (c *RemoteClient) Get() (*remote.Payload, error) { +func (c *RemoteClient) Get(ctx context.Context) (*remote.Payload, error) { options := blobs.GetInput{} if c.leaseID != "" { options.LeaseID = &c.leaseID } - ctx := context.TODO() blob, err := c.giovanniBlobClient.Get(ctx, c.accountName, c.containerName, c.keyName, options) if err != nil { if blob.Response.IsHTTPStatus(http.StatusNotFound) { @@ -60,7 +59,7 @@ func (c *RemoteClient) Get() (*remote.Payload, error) { return payload, nil } -func (c *RemoteClient) Put(data []byte) error { +func (c *RemoteClient) Put(ctx context.Context, data []byte) error { getOptions := blobs.GetPropertiesInput{} setOptions := blobs.SetPropertiesInput{} putOptions := blobs.PutBlockBlobInput{} @@ -73,8 +72,6 @@ func (c *RemoteClient) Put(data []byte) error { putOptions.LeaseID = &c.leaseID } - ctx := context.TODO() - if c.snapshot { snapshotInput := blobs.SnapshotInput{LeaseID: options.LeaseID} diff --git a/internal/backend/remote-state/azure/client_test.go b/internal/backend/remote-state/azure/client_test.go index 761be91582..e23e755683 100644 --- a/internal/backend/remote-state/azure/client_test.go +++ b/internal/backend/remote-state/azure/client_test.go @@ -297,7 +297,7 @@ func TestPutMaintainsMetaData(t *testing.T) { } bytes := []byte(acctest.RandString(20)) - err = remoteClient.Put(bytes) + err = remoteClient.Put(ctx, bytes) if err != nil { t.Fatalf("Error putting data: %+v", err) } diff --git a/internal/backend/remote-state/consul/backend_state.go b/internal/backend/remote-state/consul/backend_state.go index 9044200987..498699ea83 100644 --- a/internal/backend/remote-state/consul/backend_state.go +++ b/internal/backend/remote-state/consul/backend_state.go @@ -113,7 +113,7 @@ func (b *Backend) StateMgr(ctx context.Context, name string) (statemgr.Full, err } // Grab the value - if err := stateMgr.RefreshState(); err != nil { + if err := stateMgr.RefreshState(ctx); err != nil { err = lockUnlock(err) return nil, err } @@ -124,7 +124,7 @@ func (b *Backend) StateMgr(ctx context.Context, name string) (statemgr.Full, err err = lockUnlock(err) return nil, err } - if err := stateMgr.PersistState(nil); err != nil { + if err := stateMgr.PersistState(ctx, nil); err != nil { err = lockUnlock(err) return nil, err } diff --git a/internal/backend/remote-state/consul/client.go b/internal/backend/remote-state/consul/client.go index b9811d9ac8..6e6fd27467 100644 --- a/internal/backend/remote-state/consul/client.go +++ b/internal/backend/remote-state/consul/client.go @@ -71,7 +71,7 @@ type RemoteClient struct { sessionCancel context.CancelFunc } -func (c *RemoteClient) Get() (*remote.Payload, error) { +func (c *RemoteClient) Get(context.Context) (*remote.Payload, error) { c.mu.Lock() defer c.mu.Unlock() @@ -123,7 +123,7 @@ func (c *RemoteClient) Get() (*remote.Payload, error) { }, nil } -func (c *RemoteClient) Put(data []byte) error { +func (c *RemoteClient) Put(_ context.Context, data []byte) error { // The state can be stored in 4 different ways, based on the payload size // and whether the user enabled gzip: // - single entry mode with plain JSON: a single JSON is stored at diff --git a/internal/backend/remote-state/consul/client_test.go b/internal/backend/remote-state/consul/client_test.go index 4d6423e11d..93b31fbcc3 100644 --- a/internal/backend/remote-state/consul/client_test.go +++ b/internal/backend/remote-state/consul/client_test.go @@ -141,12 +141,12 @@ func TestConsul_largeState(t *testing.T) { if err != nil { t.Fatal(err) } - err = c.Put(payload) + err = c.Put(ctx, payload) if err != nil { t.Fatal("could not put payload", err) } - remote, err := c.Get() + remote, err := c.Get(ctx) if err != nil { t.Fatal(err) } diff --git a/internal/backend/remote-state/cos/backend_state.go b/internal/backend/remote-state/cos/backend_state.go index 8f08a0246b..cf55c07459 100644 --- a/internal/backend/remote-state/cos/backend_state.go +++ b/internal/backend/remote-state/cos/backend_state.go @@ -119,7 +119,7 @@ func (b *Backend) StateMgr(ctx context.Context, name string) (statemgr.Full, err } // Grab the value - if err := stateMgr.RefreshState(); err != nil { + if err := stateMgr.RefreshState(ctx); err != nil { err = lockUnlock(err) return nil, err } @@ -130,7 +130,7 @@ func (b *Backend) StateMgr(ctx context.Context, name string) (statemgr.Full, err err = lockUnlock(err) return nil, err } - if err := stateMgr.PersistState(nil); err != nil { + if err := stateMgr.PersistState(ctx, nil); err != nil { err = lockUnlock(err) return nil, err } diff --git a/internal/backend/remote-state/cos/client.go b/internal/backend/remote-state/cos/client.go index 105070dd48..ce025e9032 100644 --- a/internal/backend/remote-state/cos/client.go +++ b/internal/backend/remote-state/cos/client.go @@ -42,10 +42,10 @@ type remoteClient struct { } // Get returns remote state file -func (c *remoteClient) Get() (*remote.Payload, error) { +func (c *remoteClient) Get(ctx context.Context) (*remote.Payload, error) { log.Printf("[DEBUG] get remote state file %s", c.stateFile) - exists, data, checksum, err := c.getObject(c.stateFile) + exists, data, checksum, err := c.getObject(ctx, c.stateFile) if err != nil { return nil, err } @@ -63,10 +63,10 @@ func (c *remoteClient) Get() (*remote.Payload, error) { } // Put put state file to remote -func (c *remoteClient) Put(data []byte) error { +func (c *remoteClient) Put(ctx context.Context, data []byte) error { log.Printf("[DEBUG] put remote state file %s", c.stateFile) - return c.putObject(c.stateFile, data) + return c.putObject(ctx, c.stateFile, data) } // Delete delete remote state file @@ -86,7 +86,7 @@ func (c *remoteClient) Lock(info *statemgr.LockInfo) (string, error) { } defer c.cosUnlock(c.bucket, c.lockFile) - exists, _, _, err := c.getObject(c.lockFile) + exists, _, _, err := c.getObject(c.cosContext, c.lockFile) if err != nil { return "", c.lockError(err) } @@ -102,7 +102,7 @@ func (c *remoteClient) Lock(info *statemgr.LockInfo) (string, error) { } check := fmt.Sprintf("%x", md5.Sum(data)) - err = c.putObject(c.lockFile, data) + err = c.putObject(c.cosContext, c.lockFile, data) if err != nil { return "", c.lockError(err) } @@ -156,7 +156,7 @@ func (c *remoteClient) lockError(err error) *statemgr.LockError { // lockInfo returns LockInfo from lock file func (c *remoteClient) lockInfo() (*statemgr.LockInfo, error) { - exists, data, checksum, err := c.getObject(c.lockFile) + exists, data, checksum, err := c.getObject(c.cosContext, c.lockFile) if err != nil { return nil, err } @@ -176,8 +176,8 @@ func (c *remoteClient) lockInfo() (*statemgr.LockInfo, error) { } // getObject get remote object -func (c *remoteClient) getObject(cosFile string) (exists bool, data []byte, checksum string, err error) { - rsp, err := c.cosClient.Object.Get(c.cosContext, cosFile, nil) +func (c *remoteClient) getObject(ctx context.Context, cosFile string) (exists bool, data []byte, checksum string, err error) { + rsp, err := c.cosClient.Object.Get(ctx, cosFile, nil) if rsp == nil { log.Printf("[DEBUG] getObject %s: error: %v", cosFile, err) err = fmt.Errorf("failed to open file at %v: %w", cosFile, err) @@ -221,7 +221,7 @@ func (c *remoteClient) getObject(cosFile string) (exists bool, data []byte, chec } // putObject put object to remote -func (c *remoteClient) putObject(cosFile string, data []byte) error { +func (c *remoteClient) putObject(ctx context.Context, cosFile string, data []byte) error { opt := &cos.ObjectPutOptions{ ObjectPutHeaderOptions: &cos.ObjectPutHeaderOptions{ XCosMetaXXX: &http.Header{ diff --git a/internal/backend/remote-state/gcs/backend_state.go b/internal/backend/remote-state/gcs/backend_state.go index 8b40a312c7..512bf86fe6 100644 --- a/internal/backend/remote-state/gcs/backend_state.go +++ b/internal/backend/remote-state/gcs/backend_state.go @@ -100,7 +100,7 @@ func (b *Backend) StateMgr(ctx context.Context, name string) (statemgr.Full, err st := &remote.State{Client: c} // Grab the value - if err := st.RefreshState(); err != nil { + if err := st.RefreshState(ctx); err != nil { return nil, err } @@ -136,7 +136,7 @@ func (b *Backend) StateMgr(ctx context.Context, name string) (statemgr.Full, err if err := st.WriteState(states.NewState()); err != nil { return nil, unlock(err) } - if err := st.PersistState(nil); err != nil { + if err := st.PersistState(ctx, nil); err != nil { return nil, unlock(err) } diff --git a/internal/backend/remote-state/gcs/client.go b/internal/backend/remote-state/gcs/client.go index 90d2c06ebb..1006a898c5 100644 --- a/internal/backend/remote-state/gcs/client.go +++ b/internal/backend/remote-state/gcs/client.go @@ -31,8 +31,8 @@ type remoteClient struct { kmsKeyName string } -func (c *remoteClient) Get() (payload *remote.Payload, err error) { - stateFileReader, err := c.stateFile().NewReader(c.storageContext) +func (c *remoteClient) Get(ctx context.Context) (payload *remote.Payload, err error) { + stateFileReader, err := c.stateFile().NewReader(ctx) if err != nil { if err == storage.ErrObjectNotExist { return nil, nil @@ -60,9 +60,9 @@ func (c *remoteClient) Get() (payload *remote.Payload, err error) { return result, nil } -func (c *remoteClient) Put(data []byte) error { +func (c *remoteClient) Put(ctx context.Context, data []byte) error { err := func() error { - stateFileWriter := c.stateFile().NewWriter(c.storageContext) + stateFileWriter := c.stateFile().NewWriter(ctx) if len(c.kmsKeyName) > 0 { stateFileWriter.KMSKeyName = c.kmsKeyName } diff --git a/internal/backend/remote-state/http/client.go b/internal/backend/remote-state/http/client.go index cefba83ade..38b2fb42d4 100644 --- a/internal/backend/remote-state/http/client.go +++ b/internal/backend/remote-state/http/client.go @@ -144,8 +144,8 @@ func (c *httpClient) Unlock(id string) error { } } -func (c *httpClient) Get() (*remote.Payload, error) { - resp, err := c.httpRequest(context.TODO(), "GET", c.URL, nil, "get state") +func (c *httpClient) Get(ctx context.Context) (*remote.Payload, error) { + resp, err := c.httpRequest(ctx, "GET", c.URL, nil, "get state") if err != nil { return nil, err } @@ -203,7 +203,7 @@ func (c *httpClient) Get() (*remote.Payload, error) { return payload, nil } -func (c *httpClient) Put(data []byte) error { +func (c *httpClient) Put(ctx context.Context, data []byte) error { // Copy the target URL base := *c.URL @@ -226,7 +226,7 @@ func (c *httpClient) Put(data []byte) error { if c.UpdateMethod != "" { method = c.UpdateMethod } - resp, err := c.httpRequest(context.TODO(), method, &base, &data, "upload state") + resp, err := c.httpRequest(ctx, method, &base, &data, "upload state") if err != nil { return err } diff --git a/internal/backend/remote-state/http/server_test.go b/internal/backend/remote-state/http/server_test.go index 2041424a12..faea030af1 100644 --- a/internal/backend/remote-state/http/server_test.go +++ b/internal/backend/remote-state/http/server_test.go @@ -288,7 +288,7 @@ func TestMTLSServer_NoCertFails(t *testing.T) { } opErr := new(net.OpError) - err = sm.RefreshState() + err = sm.RefreshState(ctx) if err == nil { t.Fatal("expected error when refreshing state without a client cert") } @@ -359,7 +359,7 @@ func TestMTLSServer_WithCertPasses(t *testing.T) { if err != nil { t.Fatalf("unexpected error fetching StateMgr with %s: %v", backend.DefaultStateName, err) } - if err = sm.RefreshState(); err != nil { + if err = sm.RefreshState(ctx); err != nil { t.Fatalf("unexpected error calling RefreshState: %v", err) } state := sm.State() @@ -400,10 +400,10 @@ func TestMTLSServer_WithCertPasses(t *testing.T) { if err = sm.WriteState(state); err != nil { t.Errorf("error writing state: %v", err) } - if err = sm.PersistState(nil); err != nil { + if err = sm.PersistState(ctx, nil); err != nil { t.Errorf("error persisting state: %v", err) } - if err = sm.RefreshState(); err != nil { + if err = sm.RefreshState(ctx); err != nil { t.Errorf("error refreshing state: %v", err) } diff --git a/internal/backend/remote-state/inmem/backend.go b/internal/backend/remote-state/inmem/backend.go index a2cab039c8..5a8bf35316 100644 --- a/internal/backend/remote-state/inmem/backend.go +++ b/internal/backend/remote-state/inmem/backend.go @@ -144,7 +144,7 @@ func (b *Backend) StateMgr(ctx context.Context, name string) (statemgr.Full, err if err := s.WriteState(statespkg.NewState()); err != nil { return nil, err } - if err := s.PersistState(nil); err != nil { + if err := s.PersistState(ctx, nil); err != nil { return nil, err } } diff --git a/internal/backend/remote-state/inmem/backend_test.go b/internal/backend/remote-state/inmem/backend_test.go index cbbbbaad12..36f6fc6cb0 100644 --- a/internal/backend/remote-state/inmem/backend_test.go +++ b/internal/backend/remote-state/inmem/backend_test.go @@ -90,11 +90,11 @@ func TestRemoteState(t *testing.T) { t.Fatal(err) } - if err := s.PersistState(nil); err != nil { + if err := s.PersistState(ctx, nil); err != nil { t.Fatal(err) } - if err := s.RefreshState(); err != nil { + if err := s.RefreshState(ctx); err != nil { t.Fatal(err) } } diff --git a/internal/backend/remote-state/inmem/client.go b/internal/backend/remote-state/inmem/client.go index 9533525a58..c69e493cf3 100644 --- a/internal/backend/remote-state/inmem/client.go +++ b/internal/backend/remote-state/inmem/client.go @@ -18,7 +18,7 @@ type RemoteClient struct { Name string } -func (c *RemoteClient) Get() (*remote.Payload, error) { +func (c *RemoteClient) Get(context.Context) (*remote.Payload, error) { if c.Data == nil { return nil, nil } @@ -29,7 +29,7 @@ func (c *RemoteClient) Get() (*remote.Payload, error) { }, nil } -func (c *RemoteClient) Put(data []byte) error { +func (c *RemoteClient) Put(_ context.Context, data []byte) error { md5 := md5.Sum(data) c.Data = data diff --git a/internal/backend/remote-state/kubernetes/backend_state.go b/internal/backend/remote-state/kubernetes/backend_state.go index 72cdecb8f0..55439d7874 100644 --- a/internal/backend/remote-state/kubernetes/backend_state.go +++ b/internal/backend/remote-state/kubernetes/backend_state.go @@ -85,7 +85,7 @@ func (b *Backend) StateMgr(ctx context.Context, name string) (statemgr.Full, err stateMgr := &remote.State{Client: c} // Grab the value - if err := stateMgr.RefreshState(); err != nil { + if err := stateMgr.RefreshState(ctx); err != nil { return nil, err } @@ -126,7 +126,7 @@ func (b *Backend) StateMgr(ctx context.Context, name string) (statemgr.Full, err if err := stateMgr.WriteState(states.NewState()); err != nil { return nil, unlock(err) } - if err := stateMgr.PersistState(nil); err != nil { + if err := stateMgr.PersistState(ctx, nil); err != nil { return nil, unlock(err) } diff --git a/internal/backend/remote-state/kubernetes/client.go b/internal/backend/remote-state/kubernetes/client.go index e0b04c61eb..bf1878abff 100644 --- a/internal/backend/remote-state/kubernetes/client.go +++ b/internal/backend/remote-state/kubernetes/client.go @@ -45,12 +45,12 @@ type RemoteClient struct { workspace string } -func (c *RemoteClient) Get() (payload *remote.Payload, err error) { +func (c *RemoteClient) Get(ctx context.Context) (payload *remote.Payload, err error) { secretName, err := c.createSecretName() if err != nil { return nil, err } - secret, err := c.kubernetesSecretClient.Get(context.Background(), secretName, metav1.GetOptions{}) + secret, err := c.kubernetesSecretClient.Get(ctx, secretName, metav1.GetOptions{}) if err != nil { if k8serrors.IsNotFound(err) { return nil, nil @@ -81,8 +81,7 @@ func (c *RemoteClient) Get() (payload *remote.Payload, err error) { return p, nil } -func (c *RemoteClient) Put(data []byte) error { - ctx := context.Background() +func (c *RemoteClient) Put(ctx context.Context, data []byte) error { secretName, err := c.createSecretName() if err != nil { return err diff --git a/internal/backend/remote-state/oss/backend_state.go b/internal/backend/remote-state/oss/backend_state.go index ffbcc777a6..bdd5d8691d 100644 --- a/internal/backend/remote-state/oss/backend_state.go +++ b/internal/backend/remote-state/oss/backend_state.go @@ -153,7 +153,7 @@ func (b *Backend) StateMgr(ctx context.Context, name string) (statemgr.Full, err } // Grab the value - if err := stateMgr.RefreshState(); err != nil { + if err := stateMgr.RefreshState(ctx); err != nil { err = lockUnlock(err) return nil, err } @@ -164,7 +164,7 @@ func (b *Backend) StateMgr(ctx context.Context, name string) (statemgr.Full, err err = lockUnlock(err) return nil, err } - if err := stateMgr.PersistState(nil); err != nil { + if err := stateMgr.PersistState(ctx, nil); err != nil { err = lockUnlock(err) return nil, err } diff --git a/internal/backend/remote-state/oss/client.go b/internal/backend/remote-state/oss/client.go index 3745653ffd..de944bc738 100644 --- a/internal/backend/remote-state/oss/client.go +++ b/internal/backend/remote-state/oss/client.go @@ -54,7 +54,7 @@ type RemoteClient struct { otsTable string } -func (c *RemoteClient) Get() (payload *remote.Payload, err error) { +func (c *RemoteClient) Get(context.Context) (payload *remote.Payload, err error) { deadline := time.Now().Add(consistencyRetryTimeout) // If we have a checksum, and the returned payload doesn't match, we retry @@ -97,7 +97,7 @@ func (c *RemoteClient) Get() (payload *remote.Payload, err error) { return payload, nil } -func (c *RemoteClient) Put(data []byte) error { +func (c *RemoteClient) Put(_ context.Context, data []byte) error { bucket, err := c.ossClient.Bucket(c.bucketName) if err != nil { return fmt.Errorf("error getting bucket: %w", err) diff --git a/internal/backend/remote-state/oss/client_test.go b/internal/backend/remote-state/oss/client_test.go index 4aab249df0..406f1203b5 100644 --- a/internal/backend/remote-state/oss/client_test.go +++ b/internal/backend/remote-state/oss/client_test.go @@ -328,22 +328,22 @@ func TestRemoteClient_stateChecksum(t *testing.T) { client2 := s2.(*remote.State).Client // write the new state through client2 so that there is no checksum yet - if err := client2.Put(newState.Bytes()); err != nil { + if err := client2.Put(ctx, newState.Bytes()); err != nil { t.Fatal(err) } // verify that we can pull a state without a checksum - if _, err := client1.Get(); err != nil { + if _, err := client1.Get(ctx); err != nil { t.Fatal(err) } // write the new state back with its checksum - if err := client1.Put(newState.Bytes()); err != nil { + if err := client1.Put(ctx, newState.Bytes()); err != nil { t.Fatal(err) } // put an empty state in place to check for panics during get - if err := client2.Put([]byte{}); err != nil { + if err := client2.Put(ctx, []byte{}); err != nil { t.Fatal(err) } @@ -359,24 +359,24 @@ func TestRemoteClient_stateChecksum(t *testing.T) { // fetching an empty state through client1 should now error out due to a // mismatched checksum. - if _, err := client1.Get(); !strings.HasPrefix(err.Error(), errBadChecksumFmt[:80]) { + if _, err := client1.Get(ctx); !strings.HasPrefix(err.Error(), errBadChecksumFmt[:80]) { t.Fatalf("expected state checksum error: got %s", err) } // put the old state in place of the new, without updating the checksum - if err := client2.Put(oldState.Bytes()); err != nil { + if err := client2.Put(ctx, oldState.Bytes()); err != nil { t.Fatal(err) } // fetching the wrong state through client1 should now error out due to a // mismatched checksum. - if _, err := client1.Get(); !strings.HasPrefix(err.Error(), errBadChecksumFmt[:80]) { + if _, err := client1.Get(ctx); !strings.HasPrefix(err.Error(), errBadChecksumFmt[:80]) { t.Fatalf("expected state checksum error: got %s", err) } // update the state with the correct one after we Get again testChecksumHook = func() { - if err := client2.Put(newState.Bytes()); err != nil { + if err := client2.Put(ctx, newState.Bytes()); err != nil { t.Fatal(err) } testChecksumHook = nil @@ -387,7 +387,7 @@ func TestRemoteClient_stateChecksum(t *testing.T) { // this final Get will fail to fail the checksum verification, the above // callback will update the state with the correct version, and Get should // retry automatically. - if _, err := client1.Get(); err != nil { + if _, err := client1.Get(ctx); err != nil { t.Fatal(err) } } diff --git a/internal/backend/remote-state/pg/backend_state.go b/internal/backend/remote-state/pg/backend_state.go index 6ce75f09f5..21b6cc9a32 100644 --- a/internal/backend/remote-state/pg/backend_state.go +++ b/internal/backend/remote-state/pg/backend_state.go @@ -103,7 +103,7 @@ func (b *Backend) StateMgr(ctx context.Context, name string) (statemgr.Full, err err = lockUnlock(err) return nil, err } - if err := stateMgr.PersistState(nil); err != nil { + if err := stateMgr.PersistState(ctx, nil); err != nil { err = lockUnlock(err) return nil, err } diff --git a/internal/backend/remote-state/pg/backend_test.go b/internal/backend/remote-state/pg/backend_test.go index 115bfc4542..45a971b31b 100644 --- a/internal/backend/remote-state/pg/backend_test.go +++ b/internal/backend/remote-state/pg/backend_test.go @@ -475,7 +475,9 @@ func TestBackendConcurrentLock(t *testing.T) { t.Fatalf("failed to lock first state: %v", err) } - if err = s1.PersistState(nil); err != nil { + ctx := context.Background() + + if err = s1.PersistState(ctx, nil); err != nil { t.Fatalf("failed to persist state: %v", err) } @@ -488,7 +490,7 @@ func TestBackendConcurrentLock(t *testing.T) { t.Fatalf("failed to lock second state: %v", err) } - if err = s2.PersistState(nil); err != nil { + if err = s2.PersistState(ctx, nil); err != nil { t.Fatalf("failed to persist state: %v", err) } diff --git a/internal/backend/remote-state/pg/client.go b/internal/backend/remote-state/pg/client.go index 796cace604..6281555cb2 100644 --- a/internal/backend/remote-state/pg/client.go +++ b/internal/backend/remote-state/pg/client.go @@ -24,9 +24,9 @@ type RemoteClient struct { info *statemgr.LockInfo } -func (c *RemoteClient) Get() (*remote.Payload, error) { +func (c *RemoteClient) Get(ctx context.Context) (*remote.Payload, error) { query := `SELECT data FROM %s.%s WHERE name = $1` - row := c.Client.QueryRow(fmt.Sprintf(query, c.SchemaName, statesTableName), c.Name) + row := c.Client.QueryRowContext(ctx, fmt.Sprintf(query, c.SchemaName, statesTableName), c.Name) var data []byte err := row.Scan(&data) switch { @@ -44,11 +44,11 @@ func (c *RemoteClient) Get() (*remote.Payload, error) { } } -func (c *RemoteClient) Put(data []byte) error { +func (c *RemoteClient) Put(ctx context.Context, data []byte) error { query := `INSERT INTO %s.%s (name, data) VALUES ($1, $2) ON CONFLICT (name) DO UPDATE SET data = $2 WHERE %s.name = $1` - _, err := c.Client.Exec(fmt.Sprintf(query, c.SchemaName, statesTableName, statesTableName), c.Name, data) + _, err := c.Client.ExecContext(ctx, fmt.Sprintf(query, c.SchemaName, statesTableName, statesTableName), c.Name, data) if err != nil { return err } diff --git a/internal/backend/remote-state/s3/backend_state.go b/internal/backend/remote-state/s3/backend_state.go index b04d638e4b..cac11c8cb7 100644 --- a/internal/backend/remote-state/s3/backend_state.go +++ b/internal/backend/remote-state/s3/backend_state.go @@ -184,7 +184,7 @@ func (b *Backend) StateMgr(ctx context.Context, name string) (statemgr.Full, err // Grab the value // This is to ensure that no one beat us to writing a state between // the `exists` check and taking the lock. - if err := stateMgr.RefreshState(); err != nil { + if err := stateMgr.RefreshState(ctx); err != nil { err = lockUnlock(err) return nil, err } @@ -195,7 +195,7 @@ func (b *Backend) StateMgr(ctx context.Context, name string) (statemgr.Full, err err = lockUnlock(err) return nil, err } - if err := stateMgr.PersistState(nil); err != nil { + if err := stateMgr.PersistState(ctx, nil); err != nil { err = lockUnlock(err) return nil, err } diff --git a/internal/backend/remote-state/s3/backend_test.go b/internal/backend/remote-state/s3/backend_test.go index 9dea04c8f1..dc8def70e5 100644 --- a/internal/backend/remote-state/s3/backend_test.go +++ b/internal/backend/remote-state/s3/backend_test.go @@ -1060,7 +1060,7 @@ func TestBackendExtraPaths(t *testing.T) { if err := stateMgr.WriteState(s1); err != nil { t.Fatal(err) } - if err := stateMgr.PersistState(nil); err != nil { + if err := stateMgr.PersistState(ctx, nil); err != nil { t.Fatal(err) } @@ -1072,7 +1072,7 @@ func TestBackendExtraPaths(t *testing.T) { if err := stateMgr2.WriteState(s2); err != nil { t.Fatal(err) } - if err := stateMgr2.PersistState(nil); err != nil { + if err := stateMgr2.PersistState(ctx, nil); err != nil { t.Fatal(err) } @@ -1087,7 +1087,7 @@ func TestBackendExtraPaths(t *testing.T) { if err := stateMgr.WriteState(states.NewState()); err != nil { t.Fatal(err) } - if err := stateMgr.PersistState(nil); err != nil { + if err := stateMgr.PersistState(ctx, nil); err != nil { t.Fatal(err) } if err := checkStateList(b, []string{"default", "s1", "s2"}); err != nil { @@ -1099,7 +1099,7 @@ func TestBackendExtraPaths(t *testing.T) { if err := stateMgr.WriteState(states.NewState()); err != nil { t.Fatal(err) } - if err := stateMgr.PersistState(nil); err != nil { + if err := stateMgr.PersistState(ctx, nil); err != nil { t.Fatal(err) } if err := checkStateList(b, []string{"default", "s1", "s2"}); err != nil { @@ -1125,7 +1125,7 @@ func TestBackendExtraPaths(t *testing.T) { if err != nil { t.Fatal(err) } - if err := s2Mgr.RefreshState(); err != nil { + if err := s2Mgr.RefreshState(ctx); err != nil { t.Fatal(err) } @@ -1140,7 +1140,7 @@ func TestBackendExtraPaths(t *testing.T) { if err := stateMgr.WriteState(states.NewState()); err != nil { t.Fatal(err) } - if err := stateMgr.PersistState(nil); err != nil { + if err := stateMgr.PersistState(ctx, nil); err != nil { t.Fatal(err) } @@ -1149,7 +1149,7 @@ func TestBackendExtraPaths(t *testing.T) { if err != nil { t.Fatal(err) } - if err := s2Mgr.RefreshState(); err != nil { + if err := s2Mgr.RefreshState(ctx); err != nil { t.Fatal(err) } @@ -1183,7 +1183,7 @@ func TestBackendPrefixInWorkspace(t *testing.T) { if err != nil { t.Fatal(err) } - if err := sMgr.RefreshState(); err != nil { + if err := sMgr.RefreshState(ctx); err != nil { t.Fatal(err) } diff --git a/internal/backend/remote-state/s3/client.go b/internal/backend/remote-state/s3/client.go index d067d4fd1a..1c5388b21d 100644 --- a/internal/backend/remote-state/s3/client.go +++ b/internal/backend/remote-state/s3/client.go @@ -58,8 +58,7 @@ var ( // test hook called when checksums don't match var testChecksumHook func() -func (c *RemoteClient) Get() (payload *remote.Payload, err error) { - ctx := context.TODO() +func (c *RemoteClient) Get(ctx context.Context) (payload *remote.Payload, err error) { deadline := time.Now().Add(consistencyRetryTimeout) // If we have a checksum, and the returned payload doesn't match, we retry @@ -154,7 +153,7 @@ func (c *RemoteClient) get(ctx context.Context) (*remote.Payload, error) { return payload, nil } -func (c *RemoteClient) Put(data []byte) error { +func (c *RemoteClient) Put(ctx context.Context, data []byte) error { contentType := "application/json" contentLength := int64(len(data)) @@ -185,7 +184,6 @@ func (c *RemoteClient) Put(data []byte) error { log.Printf("[DEBUG] Uploading remote state to S3: %#v", i) - ctx := context.TODO() _, err := c.s3Client.PutObject(ctx, i) if err != nil { return fmt.Errorf("failed to upload state: %w", err) diff --git a/internal/backend/remote-state/s3/client_test.go b/internal/backend/remote-state/s3/client_test.go index 58025be35e..1e4575396f 100644 --- a/internal/backend/remote-state/s3/client_test.go +++ b/internal/backend/remote-state/s3/client_test.go @@ -261,22 +261,22 @@ func TestRemoteClient_stateChecksum(t *testing.T) { client2 := s2.(*remote.State).Client // write the new state through client2 so that there is no checksum yet - if err := client2.Put(newState.Bytes()); err != nil { + if err := client2.Put(ctx, newState.Bytes()); err != nil { t.Fatal(err) } // verify that we can pull a state without a checksum - if _, err := client1.Get(); err != nil { + if _, err := client1.Get(ctx); err != nil { t.Fatal(err) } // write the new state back with its checksum - if err := client1.Put(newState.Bytes()); err != nil { + if err := client1.Put(ctx, newState.Bytes()); err != nil { t.Fatal(err) } // put an empty state in place to check for panics during get - if err := client2.Put([]byte{}); err != nil { + if err := client2.Put(ctx, []byte{}); err != nil { t.Fatal(err) } @@ -292,24 +292,24 @@ func TestRemoteClient_stateChecksum(t *testing.T) { // fetching an empty state through client1 should now error out due to a // mismatched checksum. - if _, err := client1.Get(); !strings.HasPrefix(err.Error(), errBadChecksumFmt[:80]) { + if _, err := client1.Get(ctx); !strings.HasPrefix(err.Error(), errBadChecksumFmt[:80]) { t.Fatalf("expected state checksum error: got %s", err) } // put the old state in place of the new, without updating the checksum - if err := client2.Put(oldState.Bytes()); err != nil { + if err := client2.Put(ctx, oldState.Bytes()); err != nil { t.Fatal(err) } // fetching the wrong state through client1 should now error out due to a // mismatched checksum. - if _, err := client1.Get(); !strings.HasPrefix(err.Error(), errBadChecksumFmt[:80]) { + if _, err := client1.Get(ctx); !strings.HasPrefix(err.Error(), errBadChecksumFmt[:80]) { t.Fatalf("expected state checksum error: got %s", err) } // update the state with the correct one after we Get again testChecksumHook = func() { - if err := client2.Put(newState.Bytes()); err != nil { + if err := client2.Put(ctx, newState.Bytes()); err != nil { t.Fatal(err) } testChecksumHook = nil @@ -320,7 +320,7 @@ func TestRemoteClient_stateChecksum(t *testing.T) { // this final Get will fail to fail the checksum verification, the above // callback will update the state with the correct version, and Get should // retry automatically. - if _, err := client1.Get(); err != nil { + if _, err := client1.Get(ctx); err != nil { t.Fatal(err) } } diff --git a/internal/backend/remote/backend_context.go b/internal/backend/remote/backend_context.go index 0e3708f207..67771054bf 100644 --- a/internal/backend/remote/backend_context.go +++ b/internal/backend/remote/backend_context.go @@ -59,7 +59,7 @@ func (b *Remote) LocalRun(op *backend.Operation) (*backend.LocalRun, statemgr.Fu }() log.Printf("[TRACE] backend/remote: reading remote state for workspace %q", remoteWorkspaceName) - if err := stateMgr.RefreshState(); err != nil { + if err := stateMgr.RefreshState(ctx); err != nil { diags = diags.Append(fmt.Errorf("error loading state: %w", err)) return nil, nil, diags } diff --git a/internal/backend/remote/backend_state.go b/internal/backend/remote/backend_state.go index 15687a1961..230228e361 100644 --- a/internal/backend/remote/backend_state.go +++ b/internal/backend/remote/backend_state.go @@ -32,9 +32,7 @@ type remoteClient struct { } // Get the remote state. -func (r *remoteClient) Get() (*remote.Payload, error) { - ctx := context.Background() - +func (r *remoteClient) Get(ctx context.Context) (*remote.Payload, error) { sv, err := r.client.StateVersions.ReadCurrent(ctx, r.workspace.ID) if err != nil { if err == tfe.ErrResourceNotFound { @@ -89,9 +87,7 @@ func (r *remoteClient) uploadStateFallback(ctx context.Context, stateFile *state } // Put the remote state. -func (r *remoteClient) Put(state []byte) error { - ctx := context.Background() - +func (r *remoteClient) Put(ctx context.Context, state []byte) error { // Read the raw state into a OpenTofu state. stateFile, err := statefile.Read(bytes.NewReader(state)) if err != nil { diff --git a/internal/backend/remote/backend_state_test.go b/internal/backend/remote/backend_state_test.go index 66fadda1b4..b6a8978c45 100644 --- a/internal/backend/remote/backend_state_test.go +++ b/internal/backend/remote/backend_state_test.go @@ -60,7 +60,7 @@ func TestRemoteClient_Put_withRunID(t *testing.T) { // Store the new state to verify (this will be done // by the mock that is used) that the run ID is set. - if err := client.Put(buf.Bytes()); err != nil { + if err := client.Put(context.Background(), buf.Bytes()); err != nil { t.Fatalf("expected no error, got %v", err) } } diff --git a/internal/backend/testing.go b/internal/backend/testing.go index 5000f13cda..17b8431202 100644 --- a/internal/backend/testing.go +++ b/internal/backend/testing.go @@ -111,7 +111,7 @@ func TestBackendStates(t *testing.T, b Backend) { if err != nil { t.Fatalf("error: %s", err) } - if err := foo.RefreshState(); err != nil { + if err := foo.RefreshState(ctx); err != nil { t.Fatalf("bad: %s", err) } if v := foo.State(); v.HasManagedResourceInstanceObjects() { @@ -122,7 +122,7 @@ func TestBackendStates(t *testing.T, b Backend) { if err != nil { t.Fatalf("error: %s", err) } - if err := bar.RefreshState(); err != nil { + if err := bar.RefreshState(ctx); err != nil { t.Fatalf("bad: %s", err) } if v := bar.State(); v.HasManagedResourceInstanceObjects() { @@ -140,7 +140,7 @@ func TestBackendStates(t *testing.T, b Backend) { if err := foo.WriteState(fooState); err != nil { t.Fatal("error writing foo state:", err) } - if err := foo.PersistState(nil); err != nil { + if err := foo.PersistState(ctx, nil); err != nil { t.Fatal("error persisting foo state:", err) } @@ -168,12 +168,12 @@ func TestBackendStates(t *testing.T, b Backend) { if err := bar.WriteState(barState); err != nil { t.Fatalf("bad: %s", err) } - if err := bar.PersistState(nil); err != nil { + if err := bar.PersistState(ctx, nil); err != nil { t.Fatalf("bad: %s", err) } // verify that foo is unchanged with the existing state manager - if err := foo.RefreshState(); err != nil { + if err := foo.RefreshState(ctx); err != nil { t.Fatal("error refreshing foo:", err) } fooState = foo.State() @@ -186,7 +186,7 @@ func TestBackendStates(t *testing.T, b Backend) { if err != nil { t.Fatal("error re-fetching state:", err) } - if err := foo.RefreshState(); err != nil { + if err := foo.RefreshState(ctx); err != nil { t.Fatal("error refreshing foo:", err) } fooState = foo.State() @@ -199,7 +199,7 @@ func TestBackendStates(t *testing.T, b Backend) { if err != nil { t.Fatal("error re-fetching state:", err) } - if err := bar.RefreshState(); err != nil { + if err := bar.RefreshState(ctx); err != nil { t.Fatal("error refreshing bar:", err) } barState = bar.State() @@ -243,7 +243,7 @@ func TestBackendStates(t *testing.T, b Backend) { if err != nil { t.Fatalf("error: %s", err) } - if err := foo.RefreshState(); err != nil { + if err := foo.RefreshState(ctx); err != nil { t.Fatalf("bad: %s", err) } if v := foo.State(); v.HasManagedResourceInstanceObjects() { @@ -318,7 +318,7 @@ func testLocksInWorkspace(t *testing.T, b1, b2 Backend, testForceUnlock bool, wo if err != nil { t.Fatalf("error: %s", err) } - if err := b1StateMgr.RefreshState(); err != nil { + if err := b1StateMgr.RefreshState(ctx); err != nil { t.Fatalf("bad: %s", err) } @@ -334,7 +334,7 @@ func testLocksInWorkspace(t *testing.T, b1, b2 Backend, testForceUnlock bool, wo if err != nil { t.Fatalf("error: %s", err) } - if err := b2StateMgr.RefreshState(); err != nil { + if err := b2StateMgr.RefreshState(ctx); err != nil { t.Fatalf("bad: %s", err) } diff --git a/internal/builtin/providers/tf/data_source_state.go b/internal/builtin/providers/tf/data_source_state.go index 4932df2bb5..e54bcd708d 100644 --- a/internal/builtin/providers/tf/data_source_state.go +++ b/internal/builtin/providers/tf/data_source_state.go @@ -143,7 +143,7 @@ func dataSourceRemoteStateRead(d cty.Value) (cty.Value, tfdiags.Diagnostics) { return cty.NilVal, diags } - if err := state.RefreshState(); err != nil { + if err := state.RefreshState(ctx); err != nil { diags = diags.Append(err) return cty.NilVal, diags } diff --git a/internal/cloud/backend_context.go b/internal/cloud/backend_context.go index 09c95ecd64..79238efa0c 100644 --- a/internal/cloud/backend_context.go +++ b/internal/cloud/backend_context.go @@ -58,7 +58,7 @@ func (b *Cloud) LocalRun(op *backend.Operation) (*backend.LocalRun, statemgr.Ful }() log.Printf("[TRACE] cloud: reading remote state for workspace %q", remoteWorkspaceName) - if err := stateMgr.RefreshState(); err != nil { + if err := stateMgr.RefreshState(ctx); err != nil { diags = diags.Append(fmt.Errorf("error loading state: %w", err)) return nil, nil, diags } diff --git a/internal/cloud/state.go b/internal/cloud/state.go index dc018a61c3..dc965d8a3b 100644 --- a/internal/cloud/state.go +++ b/internal/cloud/state.go @@ -162,7 +162,7 @@ func (s *State) WriteState(state *states.State) error { } // PersistState uploads a snapshot of the latest state as a StateVersion to Terraform Cloud -func (s *State) PersistState(schemas *tofu.Schemas) error { +func (s *State) PersistState(ctx context.Context, schemas *tofu.Schemas) error { s.mu.Lock() defer s.mu.Unlock() @@ -182,7 +182,7 @@ func (s *State) PersistState(schemas *tofu.Schemas) error { // We might be writing a new state altogether, but before we do that // we'll check to make sure there isn't already a snapshot present // that we ought to be updating. - err := s.refreshState() + err := s.refreshState(ctx) if err != nil { return fmt.Errorf("failed checking for existing remote state: %w", err) } @@ -229,7 +229,7 @@ func (s *State) PersistState(schemas *tofu.Schemas) error { return fmt.Errorf("failed to marshal outputs to json: %w", err) } - err = s.uploadState(s.lineage, s.serial, s.forcePush, buf.Bytes(), jsonState, jsonStateOutputs) + err = s.uploadState(ctx, s.lineage, s.serial, s.forcePush, buf.Bytes(), jsonState, jsonStateOutputs) if err != nil { s.stateUploadErr = true return fmt.Errorf("error uploading state: %w", err) @@ -293,9 +293,7 @@ func (s *State) uploadStateFallback(ctx context.Context, lineage string, serial return err } -func (s *State) uploadState(lineage string, serial uint64, isForcePush bool, state, jsonState, jsonStateOutputs []byte) error { - ctx := context.Background() - +func (s *State) uploadState(ctx context.Context, lineage string, serial uint64, isForcePush bool, state, jsonState, jsonStateOutputs []byte) error { options := tfe.StateVersionUploadOptions{ StateVersionCreateOptions: tfe.StateVersionCreateOptions{ Lineage: tfe.String(lineage), @@ -362,17 +360,17 @@ func (s *State) Lock(info *statemgr.LockInfo) (string, error) { } // statemgr.Refresher impl. -func (s *State) RefreshState() error { +func (s *State) RefreshState(ctx context.Context) error { s.mu.Lock() defer s.mu.Unlock() - return s.refreshState() + return s.refreshState(ctx) } // refreshState is the main implementation of RefreshState, but split out so // that we can make internal calls to it from methods that are already holding // the s.mu lock. -func (s *State) refreshState() error { - payload, err := s.getStatePayload() +func (s *State) refreshState(ctx context.Context) error { + payload, err := s.getStatePayload(ctx) if err != nil { return err } @@ -402,9 +400,7 @@ func (s *State) refreshState() error { return nil } -func (s *State) getStatePayload() (*remote.Payload, error) { - ctx := context.Background() - +func (s *State) getStatePayload(ctx context.Context) (*remote.Payload, error) { // Check the x-terraform-snapshot-interval header to see if it has a non-empty // value which would indicate snapshots are enabled ctx = tfe.ContextWithResponseHeaderHook(ctx, s.readSnapshotIntervalHeader) @@ -516,9 +512,7 @@ func (s *State) Delete(ctx context.Context, force bool) error { } // GetRootOutputValues fetches output values from Terraform Cloud -func (s *State) GetRootOutputValues() (map[string]*states.OutputValue, error) { - ctx := context.Background() - +func (s *State) GetRootOutputValues(ctx context.Context) (map[string]*states.OutputValue, error) { so, err := s.tfeClient.StateVersionOutputs.ReadCurrent(ctx, s.workspace.ID) if err != nil { @@ -535,7 +529,7 @@ func (s *State) GetRootOutputValues() (map[string]*states.OutputValue, error) { // requires a higher level of authorization. log.Printf("[DEBUG] falling back to reading full state") - if err := s.RefreshState(); err != nil { + if err := s.RefreshState(ctx); err != nil { return nil, fmt.Errorf("failed to load state: %w", err) } diff --git a/internal/cloud/state_test.go b/internal/cloud/state_test.go index d3f8a32c80..ab6f9707ba 100644 --- a/internal/cloud/state_test.go +++ b/internal/cloud/state_test.go @@ -41,7 +41,10 @@ func TestState_GetRootOutputValues(t *testing.T) { state := &State{tfeClient: b.client, organization: b.organization, workspace: &tfe.Workspace{ ID: "ws-abcd", }} - outputs, err := state.GetRootOutputValues() + + ctx := context.Background() + + outputs, err := state.GetRootOutputValues(ctx) if err != nil { t.Fatalf("error returned from GetRootOutputValues: %s", err) @@ -115,11 +118,13 @@ func TestState(t *testing.T) { } }`) - if err := state.uploadState(state.lineage, state.serial, state.forcePush, data, jsonState, jsonStateOutputs); err != nil { + ctx := context.Background() + + if err := state.uploadState(ctx, state.lineage, state.serial, state.forcePush, data, jsonState, jsonStateOutputs); err != nil { t.Fatalf("put: %s", err) } - payload, err := state.getStatePayload() + payload, err := state.getStatePayload(ctx) if err != nil { t.Fatalf("get: %s", err) } @@ -127,13 +132,11 @@ func TestState(t *testing.T) { t.Fatalf("expected full state %q\n\ngot: %q", string(payload.Data), string(data)) } - ctx := context.Background() - if err := state.Delete(ctx, true); err != nil { t.Fatalf("delete: %s", err) } - p, err := state.getStatePayload() + p, err := state.getStatePayload(ctx) if err != nil { t.Fatalf("get: %s", err) } @@ -277,7 +280,7 @@ func TestState_PersistState(t *testing.T) { t.Fatal("expected nil initial readState") } - err := cloudState.PersistState(nil) + err := cloudState.PersistState(context.Background(), nil) if err != nil { t.Fatalf("expected no error, got %q", err) } @@ -337,7 +340,9 @@ func TestState_PersistState(t *testing.T) { } cloudState.tfeClient = client - err = cloudState.RefreshState() + ctx := context.Background() + + err = cloudState.RefreshState(ctx) if err != nil { t.Fatal(err) } @@ -348,7 +353,7 @@ func TestState_PersistState(t *testing.T) { ) })) - err = cloudState.PersistState(nil) + err = cloudState.PersistState(ctx, nil) if err != nil { t.Fatal(err) } diff --git a/internal/command/import.go b/internal/command/import.go index ac44d31b2f..7c225e17f2 100644 --- a/internal/command/import.go +++ b/internal/command/import.go @@ -4,6 +4,7 @@ package command import ( + "context" "errors" "fmt" "log" @@ -269,7 +270,10 @@ func (c *ImportCommand) Run(args []string) int { c.Ui.Error(fmt.Sprintf("Error writing state file: %s", err)) return 1 } - if err := state.PersistState(schemas); err != nil { + + ctx := context.TODO() + + if err := state.PersistState(ctx, schemas); err != nil { c.Ui.Error(fmt.Sprintf("Error writing state file: %s", err)) return 1 } diff --git a/internal/command/init.go b/internal/command/init.go index 42b5832e10..2e7fc3743e 100644 --- a/internal/command/init.go +++ b/internal/command/init.go @@ -225,7 +225,7 @@ func (c *InitCommand) Run(args []string) int { return 1 } - if err := sMgr.RefreshState(); err != nil { + if err := sMgr.RefreshState(ctx); err != nil { c.Ui.Error(fmt.Sprintf("Error refreshing state: %s", err)) return 1 } diff --git a/internal/command/meta_backend.go b/internal/command/meta_backend.go index f86d4f86fd..82cf8865ec 100644 --- a/internal/command/meta_backend.go +++ b/internal/command/meta_backend.go @@ -986,7 +986,7 @@ func (m *Meta) backend_C_r_s(c *configs.Backend, cHash int, sMgr *clistate.Local diags = diags.Append(fmt.Errorf(errBackendLocalRead, err)) return nil, diags } - if err := localState.RefreshState(); err != nil { + if err := localState.RefreshState(ctx); err != nil { diags = diags.Append(fmt.Errorf(errBackendLocalRead, err)) return nil, diags } @@ -1049,7 +1049,7 @@ func (m *Meta) backend_C_r_s(c *configs.Backend, cHash int, sMgr *clistate.Local diags = diags.Append(fmt.Errorf(errBackendMigrateLocalDelete, err)) return nil, diags } - if err := localState.PersistState(nil); err != nil { + if err := localState.PersistState(ctx, nil); err != nil { diags = diags.Append(fmt.Errorf(errBackendMigrateLocalDelete, err)) return nil, diags } diff --git a/internal/command/meta_backend_migrate.go b/internal/command/meta_backend_migrate.go index 9b61d32352..e46b570a2e 100644 --- a/internal/command/meta_backend_migrate.go +++ b/internal/command/meta_backend_migrate.go @@ -267,7 +267,7 @@ func (m *Meta) backendMigrateState_s_s(opts *backendMigrateOpts) error { return fmt.Errorf(strings.TrimSpace( errMigrateSingleLoadDefault), opts.SourceType, err) } - if err := sourceState.RefreshState(); err != nil { + if err := sourceState.RefreshState(ctx); err != nil { return fmt.Errorf(strings.TrimSpace( errMigrateSingleLoadDefault), opts.SourceType, err) } @@ -315,7 +315,7 @@ func (m *Meta) backendMigrateState_s_s(opts *backendMigrateOpts) error { return fmt.Errorf(strings.TrimSpace( errMigrateSingleLoadDefault), opts.DestinationType, err) } - if err := destinationState.RefreshState(); err != nil { + if err := destinationState.RefreshState(ctx); err != nil { return fmt.Errorf(strings.TrimSpace( errMigrateSingleLoadDefault), opts.DestinationType, err) } @@ -368,12 +368,12 @@ func (m *Meta) backendMigrateState_s_s(opts *backendMigrateOpts) error { // We now own a lock, so double check that we have the version // corresponding to the lock. log.Print("[TRACE] backendMigrateState: refreshing source workspace state") - if err := sourceState.RefreshState(); err != nil { + if err := sourceState.RefreshState(ctx); err != nil { return fmt.Errorf(strings.TrimSpace( errMigrateSingleLoadDefault), opts.SourceType, err) } log.Print("[TRACE] backendMigrateState: refreshing destination workspace state") - if err := destinationState.RefreshState(); err != nil { + if err := destinationState.RefreshState(ctx); err != nil { return fmt.Errorf(strings.TrimSpace( errMigrateSingleLoadDefault), opts.SourceType, err) } @@ -451,8 +451,8 @@ func (m *Meta) backendMigrateState_s_s(opts *backendMigrateOpts) error { // The backend is currently handled before providers are installed during init, // so requiring schemas here could lead to a catch-22 where it requires some manual // intervention to proceed far enough for provider installation. To avoid this, - // when migrating to TFC backend, the initial JSON varient of state won't be generated and stored. - if err := destinationState.PersistState(nil); err != nil { + // when migrating to TFC backend, the initial JSON variant of state won't be generated and stored. + if err := destinationState.PersistState(ctx, nil); err != nil { return fmt.Errorf(strings.TrimSpace(errBackendStateCopy), opts.SourceType, opts.DestinationType, err) } @@ -602,7 +602,7 @@ func (m *Meta) backendMigrateTFC(opts *backendMigrateOpts) error { if err != nil { return err } - if err := sourceState.RefreshState(); err != nil { + if err := sourceState.RefreshState(ctx); err != nil { return err } if sourceState.State().Empty() { @@ -692,7 +692,7 @@ func (m *Meta) backendMigrateState_S_TFC(opts *backendMigrateOpts, sourceWorkspa errMigrateSingleLoadDefault), opts.SourceType, err) } // RefreshState is what actually pulls the state to be evaluated. - if err := sourceState.RefreshState(); err != nil { + if err := sourceState.RefreshState(ctx); err != nil { return fmt.Errorf(strings.TrimSpace( errMigrateSingleLoadDefault), opts.SourceType, err) } diff --git a/internal/command/meta_backend_test.go b/internal/command/meta_backend_test.go index aef9f6a90c..d688c9f907 100644 --- a/internal/command/meta_backend_test.go +++ b/internal/command/meta_backend_test.go @@ -51,7 +51,7 @@ func TestMetaBackend_emptyDir(t *testing.T) { t.Fatalf("unexpected error: %s", err) } s.WriteState(testState()) - if err := s.PersistState(nil); err != nil { + if err := s.PersistState(ctx, nil); err != nil { t.Fatalf("unexpected error: %s", err) } @@ -121,7 +121,7 @@ func TestMetaBackend_emptyWithDefaultState(t *testing.T) { if err != nil { t.Fatalf("unexpected error: %s", err) } - if err := s.RefreshState(); err != nil { + if err := s.RefreshState(ctx); err != nil { t.Fatalf("err: %s", err) } if actual := s.State().String(); actual != testState().String() { @@ -142,7 +142,7 @@ func TestMetaBackend_emptyWithDefaultState(t *testing.T) { next := testState() next.RootModule().SetOutputValue("foo", cty.StringVal("bar"), false) s.WriteState(next) - if err := s.PersistState(nil); err != nil { + if err := s.PersistState(ctx, nil); err != nil { t.Fatalf("unexpected error: %s", err) } @@ -194,7 +194,7 @@ func TestMetaBackend_emptyWithExplicitState(t *testing.T) { if err != nil { t.Fatalf("unexpected error: %s", err) } - if err := s.RefreshState(); err != nil { + if err := s.RefreshState(ctx); err != nil { t.Fatalf("err: %s", err) } if actual := s.State().String(); actual != testState().String() { @@ -215,7 +215,7 @@ func TestMetaBackend_emptyWithExplicitState(t *testing.T) { next := testState() markStateForMatching(next, "bar") // just any change so it shows as different than before s.WriteState(next) - if err := s.PersistState(nil); err != nil { + if err := s.PersistState(ctx, nil); err != nil { t.Fatalf("unexpected error: %s", err) } @@ -264,7 +264,7 @@ func TestMetaBackend_configureNew(t *testing.T) { if err != nil { t.Fatalf("unexpected error: %s", err) } - if err := s.RefreshState(); err != nil { + if err := s.RefreshState(ctx); err != nil { t.Fatalf("unexpected error: %s", err) } state := s.State() @@ -277,7 +277,7 @@ func TestMetaBackend_configureNew(t *testing.T) { mark := markStateForMatching(state, "changing") s.WriteState(state) - if err := s.PersistState(nil); err != nil { + if err := s.PersistState(ctx, nil); err != nil { t.Fatalf("unexpected error: %s", err) } @@ -462,7 +462,7 @@ func TestMetaBackend_configureNewWithStateNoMigrate(t *testing.T) { if err != nil { t.Fatalf("unexpected error: %s", err) } - if err := s.RefreshState(); err != nil { + if err := s.RefreshState(ctx); err != nil { t.Fatalf("unexpected error: %s", err) } if state := s.State(); state != nil { @@ -507,7 +507,7 @@ func TestMetaBackend_configureNewWithStateExisting(t *testing.T) { if err != nil { t.Fatalf("unexpected error: %s", err) } - if err := s.RefreshState(); err != nil { + if err := s.RefreshState(ctx); err != nil { t.Fatalf("unexpected error: %s", err) } state := s.State() @@ -523,7 +523,7 @@ func TestMetaBackend_configureNewWithStateExisting(t *testing.T) { mark := markStateForMatching(state, "changing") s.WriteState(state) - if err := s.PersistState(nil); err != nil { + if err := s.PersistState(ctx, nil); err != nil { t.Fatalf("unexpected error: %s", err) } @@ -581,7 +581,7 @@ func TestMetaBackend_configureNewWithStateExistingNoMigrate(t *testing.T) { if err != nil { t.Fatalf("unexpected error: %s", err) } - if err := s.RefreshState(); err != nil { + if err := s.RefreshState(ctx); err != nil { t.Fatalf("unexpected error: %s", err) } state := s.State() @@ -596,7 +596,7 @@ func TestMetaBackend_configureNewWithStateExistingNoMigrate(t *testing.T) { state = states.NewState() mark := markStateForMatching(state, "changing") s.WriteState(state) - if err := s.PersistState(nil); err != nil { + if err := s.PersistState(ctx, nil); err != nil { t.Fatalf("unexpected error: %s", err) } @@ -648,7 +648,7 @@ func TestMetaBackend_configuredUnchanged(t *testing.T) { if err != nil { t.Fatalf("unexpected error: %s", err) } - if err := s.RefreshState(); err != nil { + if err := s.RefreshState(ctx); err != nil { t.Fatalf("unexpected error: %s", err) } state := s.State() @@ -696,7 +696,7 @@ func TestMetaBackend_configuredChange(t *testing.T) { if err != nil { t.Fatalf("unexpected error: %s", err) } - if err := s.RefreshState(); err != nil { + if err := s.RefreshState(ctx); err != nil { t.Fatalf("unexpected error: %s", err) } state := s.State() @@ -719,7 +719,7 @@ func TestMetaBackend_configuredChange(t *testing.T) { mark := markStateForMatching(state, "changing") s.WriteState(state) - if err := s.PersistState(nil); err != nil { + if err := s.PersistState(ctx, nil); err != nil { t.Fatalf("unexpected error: %s", err) } @@ -784,7 +784,7 @@ func TestMetaBackend_reconfigureChange(t *testing.T) { if err != nil { t.Fatalf("unexpected error: %s", err) } - if err := s.RefreshState(); err != nil { + if err := s.RefreshState(ctx); err != nil { t.Fatalf("unexpected error: %s", err) } newState := s.State() @@ -794,7 +794,7 @@ func TestMetaBackend_reconfigureChange(t *testing.T) { // verify that the old state is still there s = statemgr.NewFilesystem("local-state.tfstate") - if err := s.RefreshState(); err != nil { + if err := s.RefreshState(ctx); err != nil { t.Fatal(err) } oldState := s.State() @@ -926,7 +926,7 @@ func TestMetaBackend_configuredChangeCopy(t *testing.T) { if err != nil { t.Fatalf("unexpected error: %s", err) } - if err := s.RefreshState(); err != nil { + if err := s.RefreshState(ctx); err != nil { t.Fatalf("unexpected error: %s", err) } state := s.State() @@ -981,7 +981,7 @@ func TestMetaBackend_configuredChangeCopy_singleState(t *testing.T) { if err != nil { t.Fatalf("unexpected error: %s", err) } - if err := s.RefreshState(); err != nil { + if err := s.RefreshState(ctx); err != nil { t.Fatalf("unexpected error: %s", err) } state := s.State() @@ -1037,7 +1037,7 @@ func TestMetaBackend_configuredChangeCopy_multiToSingleDefault(t *testing.T) { if err != nil { t.Fatalf("unexpected error: %s", err) } - if err := s.RefreshState(); err != nil { + if err := s.RefreshState(ctx); err != nil { t.Fatalf("unexpected error: %s", err) } state := s.State() @@ -1093,7 +1093,7 @@ func TestMetaBackend_configuredChangeCopy_multiToSingle(t *testing.T) { if err != nil { t.Fatalf("unexpected error: %s", err) } - if err := s.RefreshState(); err != nil { + if err := s.RefreshState(ctx); err != nil { t.Fatalf("unexpected error: %s", err) } state := s.State() @@ -1169,7 +1169,7 @@ func TestMetaBackend_configuredChangeCopy_multiToSingleCurrentEnv(t *testing.T) if err != nil { t.Fatalf("unexpected error: %s", err) } - if err := s.RefreshState(); err != nil { + if err := s.RefreshState(ctx); err != nil { t.Fatalf("unexpected error: %s", err) } state := s.State() @@ -1239,7 +1239,7 @@ func TestMetaBackend_configuredChangeCopy_multiToMulti(t *testing.T) { if err != nil { t.Fatalf("unexpected error: %s", err) } - if err := s.RefreshState(); err != nil { + if err := s.RefreshState(ctx); err != nil { t.Fatalf("unexpected error: %s", err) } state := s.State() @@ -1257,7 +1257,7 @@ func TestMetaBackend_configuredChangeCopy_multiToMulti(t *testing.T) { if err != nil { t.Fatalf("unexpected error: %s", err) } - if err := s.RefreshState(); err != nil { + if err := s.RefreshState(ctx); err != nil { t.Fatalf("unexpected error: %s", err) } state := s.State() @@ -1339,7 +1339,7 @@ func TestMetaBackend_configuredChangeCopy_multiToNoDefaultWithDefault(t *testing if err != nil { t.Fatalf("unexpected error: %s", err) } - if err := s.RefreshState(); err != nil { + if err := s.RefreshState(ctx); err != nil { t.Fatalf("unexpected error: %s", err) } state := s.State() @@ -1415,7 +1415,7 @@ func TestMetaBackend_configuredChangeCopy_multiToNoDefaultWithoutDefault(t *test if err != nil { t.Fatalf("unexpected error: %s", err) } - if err := s.RefreshState(); err != nil { + if err := s.RefreshState(ctx); err != nil { t.Fatalf("unexpected error: %s", err) } state := s.State() @@ -1470,7 +1470,7 @@ func TestMetaBackend_configuredUnset(t *testing.T) { if err != nil { t.Fatalf("unexpected error: %s", err) } - if err := s.RefreshState(); err != nil { + if err := s.RefreshState(ctx); err != nil { t.Fatalf("unexpected error: %s", err) } state := s.State() @@ -1492,7 +1492,7 @@ func TestMetaBackend_configuredUnset(t *testing.T) { // Write some state s.WriteState(testState()) - if err := s.PersistState(nil); err != nil { + if err := s.PersistState(ctx, nil); err != nil { t.Fatalf("unexpected error: %s", err) } @@ -1534,7 +1534,7 @@ func TestMetaBackend_configuredUnsetCopy(t *testing.T) { if err != nil { t.Fatalf("unexpected error: %s", err) } - if err := s.RefreshState(); err != nil { + if err := s.RefreshState(ctx); err != nil { t.Fatalf("unexpected error: %s", err) } state := s.State() @@ -1552,7 +1552,7 @@ func TestMetaBackend_configuredUnsetCopy(t *testing.T) { // Write some state s.WriteState(testState()) - if err := s.PersistState(nil); err != nil { + if err := s.PersistState(ctx, nil); err != nil { t.Fatalf("unexpected error: %s", err) } @@ -1604,7 +1604,7 @@ func TestMetaBackend_planLocal(t *testing.T) { if err != nil { t.Fatalf("unexpected error: %s", err) } - if err := s.RefreshState(); err != nil { + if err := s.RefreshState(ctx); err != nil { t.Fatalf("unexpected error: %s", err) } state := s.State() @@ -1633,7 +1633,7 @@ func TestMetaBackend_planLocal(t *testing.T) { mark := markStateForMatching(state, "changing") s.WriteState(state) - if err := s.PersistState(nil); err != nil { + if err := s.PersistState(ctx, nil); err != nil { t.Fatalf("unexpected error: %s", err) } @@ -1707,7 +1707,7 @@ func TestMetaBackend_planLocalStatePath(t *testing.T) { if err != nil { t.Fatalf("unexpected error: %s", err) } - if err := s.RefreshState(); err != nil { + if err := s.RefreshState(ctx); err != nil { t.Fatalf("unexpected error: %s", err) } state := s.State() @@ -1736,7 +1736,7 @@ func TestMetaBackend_planLocalStatePath(t *testing.T) { mark := markStateForMatching(state, "changing") s.WriteState(state) - if err := s.PersistState(nil); err != nil { + if err := s.PersistState(ctx, nil); err != nil { t.Fatalf("unexpected error: %s", err) } @@ -1798,7 +1798,7 @@ func TestMetaBackend_planLocalMatch(t *testing.T) { if err != nil { t.Fatalf("unexpected error: %s", err) } - if err := s.RefreshState(); err != nil { + if err := s.RefreshState(ctx); err != nil { t.Fatalf("unexpected error: %s", err) } state := s.State() @@ -1825,7 +1825,7 @@ func TestMetaBackend_planLocalMatch(t *testing.T) { mark := markStateForMatching(state, "changing") s.WriteState(state) - if err := s.PersistState(nil); err != nil { + if err := s.PersistState(ctx, nil); err != nil { t.Fatalf("unexpected error: %s", err) } diff --git a/internal/command/output.go b/internal/command/output.go index 745c4d4e2d..958c1f2cc0 100644 --- a/internal/command/output.go +++ b/internal/command/output.go @@ -88,7 +88,7 @@ func (c *OutputCommand) Outputs(statePath string) (map[string]*states.OutputValu return nil, diags } - output, err := stateStore.GetRootOutputValues() + output, err := stateStore.GetRootOutputValues(ctx) if err != nil { return nil, diags.Append(err) } diff --git a/internal/command/providers.go b/internal/command/providers.go index c75f851416..ed47140741 100644 --- a/internal/command/providers.go +++ b/internal/command/providers.go @@ -108,7 +108,7 @@ func (c *ProvidersCommand) Run(args []string) int { c.Ui.Error(fmt.Sprintf("Failed to load state: %s", err)) return 1 } - if err := s.RefreshState(); err != nil { + if err := s.RefreshState(ctx); err != nil { c.Ui.Error(fmt.Sprintf("Failed to load state: %s", err)) return 1 } diff --git a/internal/command/refresh_test.go b/internal/command/refresh_test.go index d83d8ceae7..46378f4416 100644 --- a/internal/command/refresh_test.go +++ b/internal/command/refresh_test.go @@ -5,6 +5,7 @@ package command import ( "bytes" + "context" "fmt" "os" "path/filepath" @@ -242,7 +243,7 @@ func TestRefresh_defaultState(t *testing.T) { statePath := testStateFile(t, originalState) localState := statemgr.NewFilesystem(statePath) - if err := localState.RefreshState(); err != nil { + if err := localState.RefreshState(context.Background()); err != nil { t.Fatal(err) } s := localState.State() diff --git a/internal/command/show.go b/internal/command/show.go index 8c7d03ee73..98a498ee3c 100644 --- a/internal/command/show.go +++ b/internal/command/show.go @@ -345,7 +345,7 @@ func getStateFromBackend(b backend.Backend, workspace string) (*statefile.File, } // Refresh the state store with the latest state snapshot from persistent storage - if err := stateStore.RefreshState(); err != nil { + if err := stateStore.RefreshState(ctx); err != nil { return nil, fmt.Errorf("Failed to load state: %w", err) } diff --git a/internal/command/state_list.go b/internal/command/state_list.go index 8fec24dbf7..591adfaaf7 100644 --- a/internal/command/state_list.go +++ b/internal/command/state_list.go @@ -62,7 +62,7 @@ func (c *StateListCommand) Run(args []string) int { c.Ui.Error(fmt.Sprintf(errStateLoadingState, err)) return 1 } - if err := stateMgr.RefreshState(); err != nil { + if err := stateMgr.RefreshState(ctx); err != nil { c.Ui.Error(fmt.Sprintf("Failed to load state: %s", err)) return 1 } diff --git a/internal/command/state_mv.go b/internal/command/state_mv.go index 1352d5142e..4adaea557c 100644 --- a/internal/command/state_mv.go +++ b/internal/command/state_mv.go @@ -4,6 +4,7 @@ package command import ( + "context" "fmt" "strings" @@ -110,7 +111,9 @@ func (c *StateMvCommand) Run(args []string) int { }() } - if err := stateFromMgr.RefreshState(); err != nil { + ctx := context.TODO() + + if err := stateFromMgr.RefreshState(ctx); err != nil { c.Ui.Error(fmt.Sprintf("Failed to refresh source state: %s", err)) return 1 } @@ -148,7 +151,7 @@ func (c *StateMvCommand) Run(args []string) int { }() } - if err := stateToMgr.RefreshState(); err != nil { + if err := stateToMgr.RefreshState(ctx); err != nil { c.Ui.Error(fmt.Sprintf("Failed to refresh destination state: %s", err)) return 1 } @@ -410,7 +413,7 @@ func (c *StateMvCommand) Run(args []string) int { c.Ui.Error(fmt.Sprintf(errStateRmPersist, err)) return 1 } - if err := stateToMgr.PersistState(schemas); err != nil { + if err := stateToMgr.PersistState(ctx, schemas); err != nil { c.Ui.Error(fmt.Sprintf(errStateRmPersist, err)) return 1 } @@ -421,7 +424,7 @@ func (c *StateMvCommand) Run(args []string) int { c.Ui.Error(fmt.Sprintf(errStateRmPersist, err)) return 1 } - if err := stateFromMgr.PersistState(schemas); err != nil { + if err := stateFromMgr.PersistState(ctx, schemas); err != nil { c.Ui.Error(fmt.Sprintf(errStateRmPersist, err)) return 1 } diff --git a/internal/command/state_pull.go b/internal/command/state_pull.go index 6dc5c3658c..a1104fc7b8 100644 --- a/internal/command/state_pull.go +++ b/internal/command/state_pull.go @@ -56,7 +56,7 @@ func (c *StatePullCommand) Run(args []string) int { c.Ui.Error(fmt.Sprintf(errStateLoadingState, err)) return 1 } - if err := stateMgr.RefreshState(); err != nil { + if err := stateMgr.RefreshState(ctx); err != nil { c.Ui.Error(fmt.Sprintf("Failed to refresh state: %s", err)) return 1 } diff --git a/internal/command/state_push.go b/internal/command/state_push.go index 86bf5c4ab7..9af99e7da5 100644 --- a/internal/command/state_push.go +++ b/internal/command/state_push.go @@ -120,7 +120,7 @@ func (c *StatePushCommand) Run(args []string) int { }() } - if err := stateMgr.RefreshState(); err != nil { + if err := stateMgr.RefreshState(ctx); err != nil { c.Ui.Error(fmt.Sprintf("Failed to refresh destination state: %s", err)) return 1 } @@ -147,7 +147,7 @@ func (c *StatePushCommand) Run(args []string) int { c.Ui.Error(fmt.Sprintf("Failed to write state: %s", err)) return 1 } - if err := stateMgr.PersistState(schemas); err != nil { + if err := stateMgr.PersistState(ctx, schemas); err != nil { c.Ui.Error(fmt.Sprintf("Failed to persist state: %s", err)) return 1 } diff --git a/internal/command/state_push_test.go b/internal/command/state_push_test.go index 110f70abc5..cd3e898a0e 100644 --- a/internal/command/state_push_test.go +++ b/internal/command/state_push_test.go @@ -273,7 +273,7 @@ func TestStatePush_forceRemoteState(t *testing.T) { if err := sMgr.WriteState(states.NewState()); err != nil { t.Fatal(err) } - if err := sMgr.PersistState(nil); err != nil { + if err := sMgr.PersistState(ctx, nil); err != nil { t.Fatal(err) } diff --git a/internal/command/state_replace_provider.go b/internal/command/state_replace_provider.go index 45a167beba..91c0dcf98d 100644 --- a/internal/command/state_replace_provider.go +++ b/internal/command/state_replace_provider.go @@ -4,6 +4,7 @@ package command import ( + "context" "fmt" "strings" @@ -97,8 +98,10 @@ func (c *StateReplaceProviderCommand) Run(args []string) int { }() } + ctx := context.TODO() + // Refresh and load state - if err := stateMgr.RefreshState(); err != nil { + if err := stateMgr.RefreshState(ctx); err != nil { c.Ui.Error(fmt.Sprintf("Failed to refresh source state: %s", err)) return 1 } @@ -185,7 +188,7 @@ func (c *StateReplaceProviderCommand) Run(args []string) int { c.Ui.Error(fmt.Sprintf(errStateRmPersist, err)) return 1 } - if err := stateMgr.PersistState(schemas); err != nil { + if err := stateMgr.PersistState(ctx, schemas); err != nil { c.Ui.Error(fmt.Sprintf(errStateRmPersist, err)) return 1 } diff --git a/internal/command/state_rm.go b/internal/command/state_rm.go index af42a50549..d663b90123 100644 --- a/internal/command/state_rm.go +++ b/internal/command/state_rm.go @@ -4,6 +4,7 @@ package command import ( + "context" "fmt" "strings" @@ -67,7 +68,9 @@ func (c *StateRmCommand) Run(args []string) int { }() } - if err := stateMgr.RefreshState(); err != nil { + ctx := context.TODO() + + if err := stateMgr.RefreshState(ctx); err != nil { c.Ui.Error(fmt.Sprintf("Failed to refresh state: %s", err)) return 1 } @@ -134,7 +137,7 @@ func (c *StateRmCommand) Run(args []string) int { c.Ui.Error(fmt.Sprintf(errStateRmPersist, err)) return 1 } - if err := stateMgr.PersistState(schemas); err != nil { + if err := stateMgr.PersistState(ctx, schemas); err != nil { c.Ui.Error(fmt.Sprintf(errStateRmPersist, err)) return 1 } diff --git a/internal/command/state_show.go b/internal/command/state_show.go index 9d23c0b165..ab56790999 100644 --- a/internal/command/state_show.go +++ b/internal/command/state_show.go @@ -118,7 +118,7 @@ func (c *StateShowCommand) Run(args []string) int { c.Streams.Eprintln(fmt.Sprintf(errStateLoadingState, err)) return 1 } - if err := stateMgr.RefreshState(); err != nil { + if err := stateMgr.RefreshState(ctx); err != nil { c.Streams.Eprintf("Failed to refresh state: %s\n", err) return 1 } diff --git a/internal/command/taint.go b/internal/command/taint.go index 79e7bf79ae..908f933e5b 100644 --- a/internal/command/taint.go +++ b/internal/command/taint.go @@ -111,7 +111,7 @@ func (c *TaintCommand) Run(args []string) int { }() } - if err := stateMgr.RefreshState(); err != nil { + if err := stateMgr.RefreshState(ctx); err != nil { c.Ui.Error(fmt.Sprintf("Failed to load state: %s", err)) return 1 } @@ -186,7 +186,7 @@ func (c *TaintCommand) Run(args []string) int { c.Ui.Error(fmt.Sprintf("Error writing state file: %s", err)) return 1 } - if err := stateMgr.PersistState(schemas); err != nil { + if err := stateMgr.PersistState(ctx, schemas); err != nil { c.Ui.Error(fmt.Sprintf("Error writing state file: %s", err)) return 1 } diff --git a/internal/command/untaint.go b/internal/command/untaint.go index e17f880312..4d3b19b7c9 100644 --- a/internal/command/untaint.go +++ b/internal/command/untaint.go @@ -101,7 +101,7 @@ func (c *UntaintCommand) Run(args []string) int { }() } - if err := stateMgr.RefreshState(); err != nil { + if err := stateMgr.RefreshState(ctx); err != nil { c.Ui.Error(fmt.Sprintf("Failed to load state: %s", err)) return 1 } @@ -186,7 +186,7 @@ func (c *UntaintCommand) Run(args []string) int { c.Ui.Error(fmt.Sprintf("Error writing state file: %s", err)) return 1 } - if err := stateMgr.PersistState(schemas); err != nil { + if err := stateMgr.PersistState(ctx, schemas); err != nil { c.Ui.Error(fmt.Sprintf("Error writing state file: %s", err)) return 1 } diff --git a/internal/command/workspace_command_test.go b/internal/command/workspace_command_test.go index 011241dd44..534baf0b3b 100644 --- a/internal/command/workspace_command_test.go +++ b/internal/command/workspace_command_test.go @@ -270,7 +270,7 @@ func TestWorkspace_createWithState(t *testing.T) { newPath := filepath.Join(local.DefaultWorkspaceDir, "test", DefaultStateFilename) envState := statemgr.NewFilesystem(newPath) - err = envState.RefreshState() + err = envState.RefreshState(context.Background()) if err != nil { t.Fatal(err) } diff --git a/internal/command/workspace_delete.go b/internal/command/workspace_delete.go index af1913adfc..7c4c00f738 100644 --- a/internal/command/workspace_delete.go +++ b/internal/command/workspace_delete.go @@ -125,7 +125,7 @@ func (c *WorkspaceDeleteCommand) Run(args []string) int { stateLocker = clistate.NewNoopLocker() } - if err := stateMgr.RefreshState(); err != nil { + if err := stateMgr.RefreshState(ctx); err != nil { // We need to release the lock before exit stateLocker.Unlock() c.Ui.Error(err.Error()) diff --git a/internal/command/workspace_new.go b/internal/command/workspace_new.go index ad861cc63e..31aa76d313 100644 --- a/internal/command/workspace_new.go +++ b/internal/command/workspace_new.go @@ -163,7 +163,7 @@ func (c *WorkspaceNewCommand) Run(args []string) int { c.Ui.Error(err.Error()) return 1 } - err = stateMgr.PersistState(nil) + err = stateMgr.PersistState(ctx, nil) if err != nil { c.Ui.Error(err.Error()) return 1 diff --git a/internal/states/remote/remote.go b/internal/states/remote/remote.go index 8322886c10..1271cf0675 100644 --- a/internal/states/remote/remote.go +++ b/internal/states/remote/remote.go @@ -13,8 +13,8 @@ import ( // driver. It supports dumb put/get/delete, and the higher level structs // handle persisting the state properly here. type Client interface { - Get() (*Payload, error) - Put([]byte) error + Get(context.Context) (*Payload, error) + Put(context.Context, []byte) error Delete(context.Context) error } diff --git a/internal/states/remote/remote_test.go b/internal/states/remote/remote_test.go index e461cb53c4..9da2a746c4 100644 --- a/internal/states/remote/remote_test.go +++ b/internal/states/remote/remote_test.go @@ -14,7 +14,7 @@ func TestRemoteClient_noPayload(t *testing.T) { s := &State{ Client: nilClient{}, } - if err := s.RefreshState(); err != nil { + if err := s.RefreshState(context.Background()); err != nil { t.Fatal("error refreshing empty remote state") } } @@ -22,9 +22,9 @@ func TestRemoteClient_noPayload(t *testing.T) { // nilClient returns nil for everything type nilClient struct{} -func (nilClient) Get() (*Payload, error) { return nil, nil } +func (nilClient) Get(context.Context) (*Payload, error) { return nil, nil } -func (c nilClient) Put([]byte) error { return nil } +func (c nilClient) Put(context.Context, []byte) error { return nil } func (c nilClient) Delete(context.Context) error { return nil } @@ -41,7 +41,7 @@ type mockClientRequest struct { Content map[string]interface{} } -func (c *mockClient) Get() (*Payload, error) { +func (c *mockClient) Get(context.Context) (*Payload, error) { c.appendLog("Get", c.current) if c.current == nil { return nil, nil @@ -53,7 +53,7 @@ func (c *mockClient) Get() (*Payload, error) { }, nil } -func (c *mockClient) Put(data []byte) error { +func (c *mockClient) Put(_ context.Context, data []byte) error { c.appendLog("Put", data) c.current = data return nil @@ -90,7 +90,7 @@ type mockClientForcePusher struct { log []mockClientRequest } -func (c *mockClientForcePusher) Get() (*Payload, error) { +func (c *mockClientForcePusher) Get(context.Context) (*Payload, error) { c.appendLog("Get", c.current) if c.current == nil { return nil, nil @@ -102,7 +102,7 @@ func (c *mockClientForcePusher) Get() (*Payload, error) { }, nil } -func (c *mockClientForcePusher) Put(data []byte) error { +func (c *mockClientForcePusher) Put(_ context.Context, data []byte) error { if c.force { c.appendLog("Force Put", data) } else { diff --git a/internal/states/remote/state.go b/internal/states/remote/state.go index f23f50d8ca..161a21111c 100644 --- a/internal/states/remote/state.go +++ b/internal/states/remote/state.go @@ -5,6 +5,7 @@ package remote import ( "bytes" + "context" "fmt" "log" "sync" @@ -60,8 +61,8 @@ func (s *State) State() *states.State { return s.state.DeepCopy() } -func (s *State) GetRootOutputValues() (map[string]*states.OutputValue, error) { - if err := s.RefreshState(); err != nil { +func (s *State) GetRootOutputValues(ctx context.Context) (map[string]*states.OutputValue, error) { + if err := s.RefreshState(ctx); err != nil { return nil, fmt.Errorf("Failed to load state: %w", err) } @@ -125,17 +126,17 @@ func (s *State) WriteStateForMigration(f *statefile.File, force bool) error { } // statemgr.Refresher impl. -func (s *State) RefreshState() error { +func (s *State) RefreshState(ctx context.Context) error { s.mu.Lock() defer s.mu.Unlock() - return s.refreshState() + return s.refreshState(ctx) } // refreshState is the main implementation of RefreshState, but split out so // that we can make internal calls to it from methods that are already holding // the s.mu lock. -func (s *State) refreshState() error { - payload, err := s.Client.Get() +func (s *State) refreshState(ctx context.Context) error { + payload, err := s.Client.Get(ctx) if err != nil { return err } @@ -166,7 +167,7 @@ func (s *State) refreshState() error { } // statemgr.Persister impl. -func (s *State) PersistState(schemas *tofu.Schemas) error { +func (s *State) PersistState(ctx context.Context, schemas *tofu.Schemas) error { s.mu.Lock() defer s.mu.Unlock() @@ -186,7 +187,7 @@ func (s *State) PersistState(schemas *tofu.Schemas) error { // We might be writing a new state altogether, but before we do that // we'll check to make sure there isn't already a snapshot present // that we ought to be updating. - err := s.refreshState() + err := s.refreshState(ctx) if err != nil { return fmt.Errorf("failed checking for existing remote state: %w", err) } @@ -210,7 +211,7 @@ func (s *State) PersistState(schemas *tofu.Schemas) error { return err } - err = s.Client.Put(buf.Bytes()) + err = s.Client.Put(ctx, buf.Bytes()) if err != nil { return err } diff --git a/internal/states/remote/state_test.go b/internal/states/remote/state_test.go index 7a0bbe6458..c78435215c 100644 --- a/internal/states/remote/state_test.go +++ b/internal/states/remote/state_test.go @@ -4,6 +4,7 @@ package remote import ( + "context" "log" "sync" "testing" @@ -41,10 +42,12 @@ func TestStateRace(t *testing.T) { for i := 0; i < 100; i++ { wg.Add(1) go func() { + ctx := context.Background() + defer wg.Done() s.WriteState(current) - s.PersistState(nil) - s.RefreshState() + s.PersistState(ctx, nil) + s.RefreshState(ctx) }() } wg.Wait() @@ -329,11 +332,13 @@ func TestStatePersist(t *testing.T) { Client: &mockClient{}, } + ctx := context.Background() + // In normal use (during a OpenTofu operation) we always refresh and read // before any writes would happen, so we'll mimic that here for realism. // NB This causes a GET to be logged so the first item in the test cases // must account for this - if err := mgr.RefreshState(); err != nil { + if err := mgr.RefreshState(ctx); err != nil { t.Fatalf("failed to RefreshState: %s", err) } @@ -354,7 +359,7 @@ func TestStatePersist(t *testing.T) { if err := mgr.WriteState(s); err != nil { t.Fatalf("failed to WriteState for %q: %s", tc.name, err) } - if err := mgr.PersistState(nil); err != nil { + if err := mgr.PersistState(ctx, nil); err != nil { t.Fatalf("failed to PersistState for %q: %s", tc.name, err) } @@ -402,7 +407,7 @@ func TestState_GetRootOutputValues(t *testing.T) { }, } - outputs, err := mgr.GetRootOutputValues() + outputs, err := mgr.GetRootOutputValues(context.Background()) if err != nil { t.Errorf("Expected GetRootOutputValues to not return an error, but it returned %v", err) } @@ -513,11 +518,13 @@ func TestWriteStateForMigration(t *testing.T) { }, } + ctx := context.Background() + // In normal use (during a OpenTofu operation) we always refresh and read // before any writes would happen, so we'll mimic that here for realism. // NB This causes a GET to be logged so the first item in the test cases // must account for this - if err := mgr.RefreshState(); err != nil { + if err := mgr.RefreshState(ctx); err != nil { t.Fatalf("failed to RefreshState: %s", err) } @@ -557,7 +564,7 @@ func TestWriteStateForMigration(t *testing.T) { // At this point we should just do a normal write and persist // as would happen from the CLI mgr.WriteState(mgr.State()) - mgr.PersistState(nil) + mgr.PersistState(ctx, nil) if logIdx >= len(mockClient.log) { t.Fatalf("request lock and index are out of sync on %q: idx=%d len=%d", tc.name, logIdx, len(mockClient.log)) @@ -669,11 +676,13 @@ func TestWriteStateForMigrationWithForcePushClient(t *testing.T) { }, } + ctx := context.Background() + // In normal use (during a OpenTofu operation) we always refresh and read // before any writes would happen, so we'll mimic that here for realism. // NB This causes a GET to be logged so the first item in the test cases // must account for this - if err := mgr.RefreshState(); err != nil { + if err := mgr.RefreshState(ctx); err != nil { t.Fatalf("failed to RefreshState: %s", err) } @@ -723,7 +732,7 @@ func TestWriteStateForMigrationWithForcePushClient(t *testing.T) { // At this point we should just do a normal write and persist // as would happen from the CLI mgr.WriteState(mgr.State()) - mgr.PersistState(nil) + mgr.PersistState(ctx, nil) if logIdx >= len(mockClient.log) { t.Fatalf("request lock and index are out of sync on %q: idx=%d len=%d", tc.name, logIdx, len(mockClient.log)) diff --git a/internal/states/remote/testing.go b/internal/states/remote/testing.go index 4e9d657599..ca7cffcc77 100644 --- a/internal/states/remote/testing.go +++ b/internal/states/remote/testing.go @@ -23,11 +23,13 @@ func TestClient(t *testing.T, c Client) { } data := buf.Bytes() - if err := c.Put(data); err != nil { + ctx := context.Background() + + if err := c.Put(ctx, data); err != nil { t.Fatalf("put: %s", err) } - p, err := c.Get() + p, err := c.Get(ctx) if err != nil { t.Fatalf("get: %s", err) } @@ -35,13 +37,11 @@ func TestClient(t *testing.T, c Client) { t.Fatalf("expected full state %q\n\ngot: %q", string(p.Data), string(data)) } - ctx := context.Background() - if err := c.Delete(ctx); err != nil { t.Fatalf("delete: %s", err) } - p, err = c.Get() + p, err = c.Get(ctx) if err != nil { t.Fatalf("get: %s", err) } diff --git a/internal/states/statemgr/filesystem.go b/internal/states/statemgr/filesystem.go index 467b4861e5..be191c0464 100644 --- a/internal/states/statemgr/filesystem.go +++ b/internal/states/statemgr/filesystem.go @@ -5,6 +5,7 @@ package statemgr import ( "bytes" + "context" "encoding/json" "fmt" "io" @@ -90,7 +91,7 @@ func NewFilesystemBetweenPaths(readPath, writePath string) *Filesystem { } } -// SetBackupPath configures the receiever so that it will create a local +// SetBackupPath configures the receiver so that it will create a local // backup file of the next state snapshot it reads (in State) if a different // snapshot is subsequently written (in WriteState). Only one backup is // written for the lifetime of the object, unless reset as described below. @@ -226,18 +227,18 @@ func (s *Filesystem) writeState(state *states.State, meta *SnapshotMeta) error { // PersistState is an implementation of Persister that does nothing because // this type's Writer implementation does its own persistence. -func (s *Filesystem) PersistState(schemas *tofu.Schemas) error { +func (s *Filesystem) PersistState(_ context.Context, schemas *tofu.Schemas) error { return nil } // RefreshState is an implementation of Refresher. -func (s *Filesystem) RefreshState() error { +func (s *Filesystem) RefreshState(ctx context.Context) error { defer s.mutex()() return s.refreshState() } -func (s *Filesystem) GetRootOutputValues() (map[string]*states.OutputValue, error) { - err := s.RefreshState() +func (s *Filesystem) GetRootOutputValues(ctx context.Context) (map[string]*states.OutputValue, error) { + err := s.RefreshState(ctx) if err != nil { return nil, err } diff --git a/internal/states/statemgr/filesystem_test.go b/internal/states/statemgr/filesystem_test.go index cb01b59c11..ce4432d094 100644 --- a/internal/states/statemgr/filesystem_test.go +++ b/internal/states/statemgr/filesystem_test.go @@ -4,6 +4,7 @@ package statemgr import ( + "context" "os" "os/exec" "path/filepath" @@ -288,7 +289,7 @@ func TestFilesystem_backupAndReadPath(t *testing.T) { func TestFilesystem_nonExist(t *testing.T) { defer testOverrideVersion(t, "1.2.3")() ls := NewFilesystem("ishouldntexist") - if err := ls.RefreshState(); err != nil { + if err := ls.RefreshState(context.Background()); err != nil { t.Fatalf("err: %s", err) } @@ -361,7 +362,7 @@ func testFilesystem(t *testing.T) *Filesystem { f.Close() ls := NewFilesystem(f.Name()) - if err := ls.RefreshState(); err != nil { + if err := ls.RefreshState(context.Background()); err != nil { t.Fatalf("initial refresh failed: %s", err) } @@ -403,7 +404,7 @@ func TestFilesystem_refreshWhileLocked(t *testing.T) { } }() - if err := s.RefreshState(); err != nil { + if err := s.RefreshState(context.Background()); err != nil { t.Fatal(err) } @@ -416,7 +417,7 @@ func TestFilesystem_refreshWhileLocked(t *testing.T) { func TestFilesystem_GetRootOutputValues(t *testing.T) { fs := testFilesystem(t) - outputs, err := fs.GetRootOutputValues() + outputs, err := fs.GetRootOutputValues(context.Background()) if err != nil { t.Errorf("Expected GetRootOutputValues to not return an error, but it returned %v", err) } diff --git a/internal/states/statemgr/helper.go b/internal/states/statemgr/helper.go index cfad878a98..81c89aa2b7 100644 --- a/internal/states/statemgr/helper.go +++ b/internal/states/statemgr/helper.go @@ -7,6 +7,8 @@ package statemgr // operations done against full state managers. import ( + "context" + "github.com/opentofu/opentofu/internal/states" "github.com/opentofu/opentofu/internal/states/statefile" "github.com/opentofu/opentofu/internal/tofu" @@ -29,7 +31,7 @@ func NewStateFile() *statefile.File { // This is a wrapper around calling RefreshState and then State on the given // manager. func RefreshAndRead(mgr Storage) (*states.State, error) { - err := mgr.RefreshState() + err := mgr.RefreshState(context.Background()) if err != nil { return nil, err } @@ -53,5 +55,5 @@ func WriteAndPersist(mgr Storage, state *states.State, schemas *tofu.Schemas) er if err != nil { return err } - return mgr.PersistState(schemas) + return mgr.PersistState(context.Background(), schemas) } diff --git a/internal/states/statemgr/lock.go b/internal/states/statemgr/lock.go index 9eca1f04c5..38ab083402 100644 --- a/internal/states/statemgr/lock.go +++ b/internal/states/statemgr/lock.go @@ -4,6 +4,8 @@ package statemgr import ( + "context" + "github.com/opentofu/opentofu/internal/states" "github.com/opentofu/opentofu/internal/tofu" ) @@ -21,20 +23,20 @@ func (s *LockDisabled) State() *states.State { return s.Inner.State() } -func (s *LockDisabled) GetRootOutputValues() (map[string]*states.OutputValue, error) { - return s.Inner.GetRootOutputValues() +func (s *LockDisabled) GetRootOutputValues(ctx context.Context) (map[string]*states.OutputValue, error) { + return s.Inner.GetRootOutputValues(ctx) } func (s *LockDisabled) WriteState(v *states.State) error { return s.Inner.WriteState(v) } -func (s *LockDisabled) RefreshState() error { - return s.Inner.RefreshState() +func (s *LockDisabled) RefreshState(ctx context.Context) error { + return s.Inner.RefreshState(ctx) } -func (s *LockDisabled) PersistState(schemas *tofu.Schemas) error { - return s.Inner.PersistState(schemas) +func (s *LockDisabled) PersistState(ctx context.Context, schemas *tofu.Schemas) error { + return s.Inner.PersistState(ctx, schemas) } func (s *LockDisabled) Lock(info *LockInfo) (string, error) { diff --git a/internal/states/statemgr/persistent.go b/internal/states/statemgr/persistent.go index 536c8a8587..a3417a6712 100644 --- a/internal/states/statemgr/persistent.go +++ b/internal/states/statemgr/persistent.go @@ -4,6 +4,8 @@ package statemgr import ( + "context" + version "github.com/hashicorp/go-version" "github.com/opentofu/opentofu/internal/states" @@ -31,7 +33,7 @@ type Persistent interface { // to differentiate reading the state and reading the outputs within the state. type OutputReader interface { // GetRootOutputValues fetches the root module output values from state or another source - GetRootOutputValues() (map[string]*states.OutputValue, error) + GetRootOutputValues(context.Context) (map[string]*states.OutputValue, error) } // Refresher is the interface for managers that can read snapshots from @@ -61,7 +63,7 @@ type Refresher interface { // return only a subset of what was written. Callers must assume that // ephemeral portions of the state may be unpopulated after calling // RefreshState. - RefreshState() error + RefreshState(context.Context) error } // Persister is the interface for managers that can write snapshots to @@ -81,7 +83,7 @@ type Refresher interface { // state. For example, when representing state in an external JSON // representation. type Persister interface { - PersistState(*tofu.Schemas) error + PersistState(context.Context, *tofu.Schemas) error } // PersistentMeta is an optional extension to Persistent that allows inspecting diff --git a/internal/states/statemgr/statemgr_fake.go b/internal/states/statemgr/statemgr_fake.go index 11f23c837c..d8001e25ab 100644 --- a/internal/states/statemgr/statemgr_fake.go +++ b/internal/states/statemgr/statemgr_fake.go @@ -4,6 +4,7 @@ package statemgr import ( + "context" "errors" "sync" @@ -61,15 +62,15 @@ func (m *fakeFull) WriteState(s *states.State) error { return m.t.WriteState(s) } -func (m *fakeFull) RefreshState() error { +func (m *fakeFull) RefreshState(context.Context) error { return m.t.WriteState(m.fakeP.State()) } -func (m *fakeFull) PersistState(schemas *tofu.Schemas) error { +func (m *fakeFull) PersistState(_ context.Context, schemas *tofu.Schemas) error { return m.fakeP.WriteState(m.t.State()) } -func (m *fakeFull) GetRootOutputValues() (map[string]*states.OutputValue, error) { +func (m *fakeFull) GetRootOutputValues(ctx context.Context) (map[string]*states.OutputValue, error) { return m.State().RootModule().OutputValues, nil } @@ -119,7 +120,7 @@ func (m *fakeErrorFull) State() *states.State { return nil } -func (m *fakeErrorFull) GetRootOutputValues() (map[string]*states.OutputValue, error) { +func (m *fakeErrorFull) GetRootOutputValues(context.Context) (map[string]*states.OutputValue, error) { return nil, errors.New("fake state manager error") } @@ -127,11 +128,11 @@ func (m *fakeErrorFull) WriteState(s *states.State) error { return errors.New("fake state manager error") } -func (m *fakeErrorFull) RefreshState() error { +func (m *fakeErrorFull) RefreshState(context.Context) error { return errors.New("fake state manager error") } -func (m *fakeErrorFull) PersistState(schemas *tofu.Schemas) error { +func (m *fakeErrorFull) PersistState(_ context.Context, schemas *tofu.Schemas) error { return errors.New("fake state manager error") } diff --git a/internal/states/statemgr/testing.go b/internal/states/statemgr/testing.go index 0cf943f8a3..614f336d26 100644 --- a/internal/states/statemgr/testing.go +++ b/internal/states/statemgr/testing.go @@ -4,6 +4,7 @@ package statemgr import ( + "context" "reflect" "testing" @@ -25,7 +26,9 @@ import ( func TestFull(t *testing.T, s Full) { t.Helper() - if err := s.RefreshState(); err != nil { + ctx := context.Background() + + if err := s.RefreshState(ctx); err != nil { t.Fatalf("err: %s", err) } @@ -59,12 +62,12 @@ func TestFull(t *testing.T, s Full) { } // Test persistence - if err := s.PersistState(nil); err != nil { + if err := s.PersistState(ctx, nil); err != nil { t.Fatalf("err: %s", err) } // Refresh if we got it - if err := s.RefreshState(); err != nil { + if err := s.RefreshState(ctx); err != nil { t.Fatalf("err: %s", err) } @@ -84,7 +87,7 @@ func TestFull(t *testing.T, s Full) { if err := s.WriteState(current); err != nil { t.Fatalf("err: %s", err) } - if err := s.PersistState(nil); err != nil { + if err := s.PersistState(ctx, nil); err != nil { t.Fatalf("err: %s", err) } @@ -107,7 +110,7 @@ func TestFull(t *testing.T, s Full) { if err := s.WriteState(current); err != nil { t.Fatalf("err: %s", err) } - if err := s.PersistState(nil); err != nil { + if err := s.PersistState(ctx, nil); err != nil { t.Fatalf("err: %s", err) }