diff --git a/state/backup.go b/state/backup.go index c357bba495..047258f4db 100644 --- a/state/backup.go +++ b/state/backup.go @@ -1,12 +1,17 @@ package state -import "github.com/hashicorp/terraform/terraform" +import ( + "sync" + + "github.com/hashicorp/terraform/terraform" +) // BackupState wraps a State that backs up the state on the first time that // a WriteState or PersistState is called. // // If Path exists, it will be overwritten. type BackupState struct { + mu sync.Mutex Real State Path string @@ -22,6 +27,9 @@ func (s *BackupState) RefreshState() error { } func (s *BackupState) WriteState(state *terraform.State) error { + s.mu.Lock() + defer s.mu.Unlock() + if !s.done { if err := s.backup(); err != nil { return err @@ -32,6 +40,9 @@ func (s *BackupState) WriteState(state *terraform.State) error { } func (s *BackupState) PersistState() error { + s.mu.Lock() + defer s.mu.Unlock() + if !s.done { if err := s.backup(); err != nil { return err diff --git a/state/backup_test.go b/state/backup_test.go index 85f722863c..8ef0afec66 100644 --- a/state/backup_test.go +++ b/state/backup_test.go @@ -3,6 +3,7 @@ package state import ( "io/ioutil" "os" + "sync" "testing" ) @@ -31,3 +32,34 @@ func TestBackupState(t *testing.T) { t.Fatalf("bad: %d", fi.Size()) } } + +func TestBackupStateRace(t *testing.T) { + f, err := ioutil.TempFile("", "tf") + if err != nil { + t.Fatalf("err: %s", err) + } + f.Close() + defer os.Remove(f.Name()) + + ls := testLocalState(t) + defer os.Remove(ls.Path) + bs := &BackupState{ + Real: ls, + Path: f.Name(), + } + + current := TestStateInitial() + + var wg sync.WaitGroup + for i := 0; i < 100; i++ { + wg.Add(1) + go func() { + defer wg.Done() + bs.WriteState(current) + bs.PersistState() + bs.RefreshState() + }() + } + + wg.Wait() +} diff --git a/state/inmem.go b/state/inmem.go index 2bbfb3d44a..4e031896c6 100644 --- a/state/inmem.go +++ b/state/inmem.go @@ -10,18 +10,28 @@ import ( // InmemState is an in-memory state storage. type InmemState struct { + mu sync.Mutex state *terraform.State } func (s *InmemState) State() *terraform.State { + s.mu.Lock() + defer s.mu.Unlock() + return s.state.DeepCopy() } func (s *InmemState) RefreshState() error { + s.mu.Lock() + defer s.mu.Unlock() + return nil } func (s *InmemState) WriteState(state *terraform.State) error { + s.mu.Lock() + defer s.mu.Unlock() + state.IncrementSerialMaybe(s.state) s.state = state return nil diff --git a/state/local.go b/state/local.go index b4029267e5..5ce02ce57b 100644 --- a/state/local.go +++ b/state/local.go @@ -8,6 +8,7 @@ import ( "io/ioutil" "os" "path/filepath" + "sync" "time" multierror "github.com/hashicorp/go-multierror" @@ -16,6 +17,8 @@ import ( // LocalState manages a state storage that is local to the filesystem. type LocalState struct { + mu sync.Mutex + // Path is the path to read the state from. PathOut is the path to // write the state to. If PathOut is not specified, Path will be used. // If PathOut already exists, it will be overwritten. @@ -42,6 +45,9 @@ type LocalState struct { // SetState will force a specific state in-memory for this local state. func (s *LocalState) SetState(state *terraform.State) { + s.mu.Lock() + defer s.mu.Unlock() + s.state = state s.readState = state } @@ -58,6 +64,9 @@ func (s *LocalState) State() *terraform.State { // // StateWriter impl. func (s *LocalState) WriteState(state *terraform.State) error { + s.mu.Lock() + defer s.mu.Unlock() + if s.stateFileOut == nil { if err := s.createStateFiles(); err != nil { return nil @@ -99,6 +108,9 @@ func (s *LocalState) PersistState() error { // StateRefresher impl. func (s *LocalState) RefreshState() error { + s.mu.Lock() + defer s.mu.Unlock() + var reader io.Reader if !s.written { // we haven't written a state file yet, so load from Path @@ -141,6 +153,9 @@ func (s *LocalState) RefreshState() error { // Lock implements a local filesystem state.Locker. func (s *LocalState) Lock(info *LockInfo) (string, error) { + s.mu.Lock() + defer s.mu.Unlock() + if s.stateFileOut == nil { if err := s.createStateFiles(); err != nil { return "", err @@ -170,6 +185,9 @@ func (s *LocalState) Lock(info *LockInfo) (string, error) { } func (s *LocalState) Unlock(id string) error { + s.mu.Lock() + defer s.mu.Unlock() + if s.lockID == "" { return fmt.Errorf("LocalState not locked") } diff --git a/state/local_test.go b/state/local_test.go index 76abde1ce4..6333560282 100644 --- a/state/local_test.go +++ b/state/local_test.go @@ -4,6 +4,7 @@ import ( "io/ioutil" "os" "os/exec" + "sync" "testing" "github.com/hashicorp/terraform/terraform" @@ -15,6 +16,22 @@ func TestLocalState(t *testing.T) { TestState(t, ls) } +func TestLocalStateRace(t *testing.T) { + ls := testLocalState(t) + defer os.Remove(ls.Path) + + current := TestStateInitial() + + var wg sync.WaitGroup + for i := 0; i < 100; i++ { + wg.Add(1) + go func() { + defer wg.Done() + ls.WriteState(current) + }() + } +} + func TestLocalStateLocks(t *testing.T) { s := testLocalState(t) defer os.Remove(s.Path)