From 9c431aee1b3a05fe228affc7e80806e21b54b634 Mon Sep 17 00:00:00 2001 From: James Bardin Date: Wed, 12 Apr 2017 13:30:49 -0400 Subject: [PATCH] only list environments when the keyName matches Prevent extra keys in the s3 envPrefix path from showing up as listed environments. Better handle keys containing slashes Add tests for unexpected keys in s3. --- backend/remote-state/s3/backend_state.go | 24 +++-- backend/remote-state/s3/backend_test.go | 132 ++++++++++++++++++++++- 2 files changed, 148 insertions(+), 8 deletions(-) diff --git a/backend/remote-state/s3/backend_state.go b/backend/remote-state/s3/backend_state.go index 2d745156e9..f7a4d337de 100644 --- a/backend/remote-state/s3/backend_state.go +++ b/backend/remote-state/s3/backend_state.go @@ -1,6 +1,7 @@ package s3 import ( + "errors" "fmt" "sort" "strings" @@ -30,29 +31,34 @@ func (b *Backend) States() ([]string, error) { return nil, err } - var envs []string + envs := []string{backend.DefaultStateName} for _, obj := range resp.Contents { - env := keyEnv(*obj.Key) + env := b.keyEnv(*obj.Key) if env != "" { envs = append(envs, env) } } - sort.Strings(envs) - envs = append([]string{backend.DefaultStateName}, envs...) + sort.Strings(envs[1:]) return envs, nil } // extract the env name from the S3 key -func keyEnv(key string) string { - parts := strings.Split(key, "/") +func (b *Backend) keyEnv(key string) string { + // we have 3 parts, the prefix, the env name, and the key name + parts := strings.SplitN(key, "/", 3) if len(parts) < 3 { // no env here return "" } + // shouldn't happen since we listed by prefix if parts[0] != keyEnvPrefix { - // not our key, so ignore + return "" + } + + // not our key, so don't include it in our listing + if parts[2] != b.keyName { return "" } @@ -78,6 +84,10 @@ func (b *Backend) DeleteState(name string) error { } func (b *Backend) State(name string) (state.State, error) { + if name == "" { + return nil, errors.New("missing state name") + } + client := &RemoteClient{ s3Client: b.s3Client, dynClient: b.dynClient, diff --git a/backend/remote-state/s3/backend_test.go b/backend/remote-state/s3/backend_test.go index 44987683ff..d90c76a6c6 100644 --- a/backend/remote-state/s3/backend_test.go +++ b/backend/remote-state/s3/backend_test.go @@ -3,6 +3,7 @@ package s3 import ( "fmt" "os" + "reflect" "testing" "time" @@ -10,6 +11,8 @@ import ( "github.com/aws/aws-sdk-go/service/dynamodb" "github.com/aws/aws-sdk-go/service/s3" "github.com/hashicorp/terraform/backend" + "github.com/hashicorp/terraform/state/remote" + "github.com/hashicorp/terraform/terraform" ) // verify that we are doing ACC tests or the S3 tests specifically @@ -84,7 +87,7 @@ func TestBackendLocked(t *testing.T) { testACC(t) bucketName := fmt.Sprintf("terraform-remote-s3-test-%x", time.Now().Unix()) - keyName := "testState" + keyName := "test/state" b1 := backend.TestBackendConfig(t, New(), map[string]interface{}{ "bucket": bucketName, @@ -108,6 +111,133 @@ func TestBackendLocked(t *testing.T) { backend.TestBackend(t, b1, b2) } +// add some extra junk in S3 to try and confuse the env listing. +func TestBackendExtraPaths(t *testing.T) { + testACC(t) + bucketName := fmt.Sprintf("terraform-remote-s3-test-%x", time.Now().Unix()) + keyName := "test/state/tfstate" + + b := backend.TestBackendConfig(t, New(), map[string]interface{}{ + "bucket": bucketName, + "key": keyName, + "encrypt": true, + }).(*Backend) + + createS3Bucket(t, b.s3Client, bucketName) + defer deleteS3Bucket(t, b.s3Client, bucketName) + + // put multiple states in old env paths. + s1 := terraform.NewState() + s2 := terraform.NewState() + + // RemoteClient to Put things in various paths + client := &RemoteClient{ + s3Client: b.s3Client, + dynClient: b.dynClient, + bucketName: b.bucketName, + path: b.path("s1"), + serverSideEncryption: b.serverSideEncryption, + acl: b.acl, + kmsKeyID: b.kmsKeyID, + lockTable: b.lockTable, + } + + stateMgr := &remote.State{Client: client} + stateMgr.WriteState(s1) + if err := stateMgr.PersistState(); err != nil { + t.Fatal(err) + } + + client.path = b.path("s2") + stateMgr.WriteState(s2) + if err := stateMgr.PersistState(); err != nil { + t.Fatal(err) + } + + if err := checkStateList(b, []string{"default", "s1", "s2"}); err != nil { + t.Fatal(err) + } + + // put a state in an env directory name + client.path = keyEnvPrefix + "/error" + stateMgr.WriteState(terraform.NewState()) + if err := stateMgr.PersistState(); err != nil { + t.Fatal(err) + } + if err := checkStateList(b, []string{"default", "s1", "s2"}); err != nil { + t.Fatal(err) + } + + // add state with the wrong key for an existing env + client.path = keyEnvPrefix + "/s2/notTestState" + stateMgr.WriteState(terraform.NewState()) + if err := stateMgr.PersistState(); err != nil { + t.Fatal(err) + } + if err := checkStateList(b, []string{"default", "s1", "s2"}); err != nil { + t.Fatal(err) + } + + // remove the state with extra subkey + if err := b.DeleteState("s2"); err != nil { + t.Fatal(err) + } + + if err := checkStateList(b, []string{"default", "s1"}); err != nil { + t.Fatal(err) + } + + // fetch that state again, which should produce a new lineage + s2Mgr, err := b.State("s2") + if err != nil { + t.Fatal(err) + } + if err := s2Mgr.RefreshState(); err != nil { + t.Fatal(err) + } + + if s2Mgr.State().Lineage == s2.Lineage { + t.Fatal("state s2 was not deleted") + } + s2 = s2Mgr.State() + + // add a state with a key that matches an existing environment dir name + client.path = keyEnvPrefix + "/s2/" + stateMgr.WriteState(terraform.NewState()) + if err := stateMgr.PersistState(); err != nil { + t.Fatal(err) + } + + // make sure s2 is OK + s2Mgr, err = b.State("s2") + if err != nil { + t.Fatal(err) + } + if err := s2Mgr.RefreshState(); err != nil { + t.Fatal(err) + } + + if s2Mgr.State().Lineage != s2.Lineage { + t.Fatal("we got the wrong state for s2") + } + + if err := checkStateList(b, []string{"default", "s1", "s2"}); err != nil { + t.Fatal(err) + } +} + +func checkStateList(b backend.Backend, expected []string) error { + states, err := b.States() + if err != nil { + return err + } + + if !reflect.DeepEqual(states, expected) { + return fmt.Errorf("incorrect states listed: %q", states) + } + return nil +} + func createS3Bucket(t *testing.T, s3Client *s3.S3, bucketName string) { createBucketReq := &s3.CreateBucketInput{ Bucket: &bucketName,