diff --git a/internal/backend/remote-state/s3/backend.go b/internal/backend/remote-state/s3/backend.go index bc47c33af6..940553b50b 100644 --- a/internal/backend/remote-state/s3/backend.go +++ b/internal/backend/remote-state/s3/backend.go @@ -258,29 +258,6 @@ func (b *Backend) PrepareConfig(obj cty.Value) (cty.Value, tfdiags.Diagnostics) } } - if val := obj.GetAttr("sse_customer_key"); !val.IsNull() { - s := val.AsString() - if len(s) != 44 { - diags = diags.Append(tfdiags.AttributeValue( - tfdiags.Error, - "Invalid sse_customer_key value", - "sse_customer_key must be 44 characters in length", - cty.Path{cty.GetAttrStep{Name: "sse_customer_key"}}, - )) - } else { - var err error - _, err = base64.StdEncoding.DecodeString(s) - if err != nil { - diags = diags.Append(tfdiags.AttributeValue( - tfdiags.Error, - "Invalid sse_customer_key value", - fmt.Sprintf("sse_customer_key must be base64 encoded: %s", err), - cty.Path{cty.GetAttrStep{Name: "sse_customer_key"}}, - )) - } - } - } - if val := obj.GetAttr("kms_key_id"); !val.IsNull() && val.AsString() != "" { if val := obj.GetAttr("sse_customer_key"); !val.IsNull() && val.AsString() != "" { diags = diags.Append(tfdiags.AttributeValue( @@ -337,9 +314,45 @@ func (b *Backend) Configure(obj cty.Value) tfdiags.Diagnostics { b.kmsKeyID = stringAttr(obj, "kms_key_id") b.ddbTable = stringAttr(obj, "dynamodb_table") - if customerKeyString, ok := stringAttrOk(obj, "sse_customer_key"); ok { - // Validation is handled in PrepareConfig, so ignore it here - b.customerEncryptionKey, _ = base64.StdEncoding.DecodeString(customerKeyString) + // WarnOnEmptyString(), LenEquals(44), IsBase64Encoded() + if customerKey, ok := stringAttrOk(obj, "sse_customer_key"); ok { + if len(customerKey) != 44 { + diags = diags.Append(tfdiags.AttributeValue( + tfdiags.Error, + "Invalid sse_customer_key value", + "sse_customer_key must be 44 characters in length", + cty.Path{cty.GetAttrStep{Name: "sse_customer_key"}}, + )) + } else { + var err error + if b.customerEncryptionKey, err = base64.StdEncoding.DecodeString(customerKey); err != nil { + diags = diags.Append(tfdiags.AttributeValue( + tfdiags.Error, + "Invalid sse_customer_key value", + fmt.Sprintf("sse_customer_key must be base64 encoded: %s", err), + cty.Path{cty.GetAttrStep{Name: "sse_customer_key"}}, + )) + } + } + } else { + if customerKey := os.Getenv("AWS_SSE_CUSTOMER_KEY"); customerKey != "" { + if len(customerKey) != 44 { + diags = diags.Append(tfdiags.Sourceless( + tfdiags.Error, + "Invalid AWS_SSE_CUSTOMER_KEY value", + "AWS_SSE_CUSTOMER_KEY must be 44 characters in length", + )) + } else { + var err error + if b.customerEncryptionKey, err = base64.StdEncoding.DecodeString(customerKey); err != nil { + diags = diags.Append(tfdiags.Sourceless( + tfdiags.Error, + "Invalid AWS_SSE_CUSTOMER_KEY value", + fmt.Sprintf("AWS_SSE_CUSTOMER_KEY must be base64 encoded: %s", err), + )) + } + } + } } cfg := &awsbase.Config{ diff --git a/internal/backend/remote-state/s3/backend_test.go b/internal/backend/remote-state/s3/backend_test.go index d2d41d3ec6..ad57665ea8 100644 --- a/internal/backend/remote-state/s3/backend_test.go +++ b/internal/backend/remote-state/s3/backend_test.go @@ -4,6 +4,7 @@ package s3 import ( + "encoding/base64" "fmt" "net/url" "os" @@ -578,24 +579,6 @@ func TestBackendConfig_PrepareConfigValidation(t *testing.T) { }), expectedErr: `workspace_key_prefix must not start or end with '/'`, }, - "sse_customer_key invalid length": { - config: cty.ObjectVal(map[string]cty.Value{ - "bucket": cty.StringVal("test"), - "key": cty.StringVal("test"), - "region": cty.StringVal("us-west-2"), - "sse_customer_key": cty.StringVal("key"), - }), - expectedErr: `sse_customer_key must be 44 characters in length`, - }, - "sse_customer_key invalid encoding": { - config: cty.ObjectVal(map[string]cty.Value{ - "bucket": cty.StringVal("test"), - "key": cty.StringVal("test"), - "region": cty.StringVal("us-west-2"), - "sse_customer_key": cty.StringVal("====CT70aTYB2JGff7AjQtwbiLkwH4npICay1PWtmdka"), - }), - expectedErr: `sse_customer_key must be base64 encoded`, - }, "encyrption key conflict": { config: cty.ObjectVal(map[string]cty.Value{ "bucket": cty.StringVal("test"), @@ -736,21 +719,130 @@ func TestBackendLocked(t *testing.T) { backend.TestBackendStateForceUnlock(t, b1, b2) } -func TestBackendSSECustomerKey(t *testing.T) { +func TestBackendSSECustomerKeyConfig(t *testing.T) { testACC(t) - bucketName := fmt.Sprintf("terraform-remote-s3-test-%x", time.Now().Unix()) - b := backend.TestBackendConfig(t, New(), backend.TestWrapConfig(map[string]interface{}{ - "bucket": bucketName, - "encrypt": true, - "key": "test-SSE-C", - "sse_customer_key": "4Dm1n4rphuFgawxuzY/bEfvLf6rYK0gIjfaDSLlfXNk=", - })).(*Backend) + testCases := map[string]struct { + customerKey string + expectedErr string + }{ + "invalid length": { + customerKey: "test", + expectedErr: `sse_customer_key must be 44 characters in length`, + }, + "invalid encoding": { + customerKey: "====CT70aTYB2JGff7AjQtwbiLkwH4npICay1PWtmdka", + expectedErr: `sse_customer_key must be base64 encoded`, + }, + "valid": { + customerKey: "4Dm1n4rphuFgawxuzY/bEfvLf6rYK0gIjfaDSLlfXNk=", + }, + } - createS3Bucket(t, b.s3Client, bucketName) - defer deleteS3Bucket(t, b.s3Client, bucketName) + for name, testCase := range testCases { + testCase := testCase - backend.TestBackendStates(t, b) + t.Run(name, func(t *testing.T) { + bucketName := fmt.Sprintf("terraform-remote-s3-test-%x", time.Now().Unix()) + config := map[string]interface{}{ + "bucket": bucketName, + "encrypt": true, + "key": "test-SSE-C", + "sse_customer_key": testCase.customerKey, + } + + b := New().(*Backend) + diags := b.Configure(populateSchema(t, b.ConfigSchema(), hcl2shim.HCL2ValueFromConfigValue(config))) + + if testCase.expectedErr != "" { + if diags.Err() != nil { + actualErr := diags.Err().Error() + if !strings.Contains(actualErr, testCase.expectedErr) { + t.Fatalf("unexpected validation result: %v", diags.Err()) + } + } else { + t.Fatal("expected an error, got none") + } + } else { + if diags.Err() != nil { + t.Fatalf("expected no error, got %s", diags.Err()) + } + if string(b.customerEncryptionKey) != string(must(base64.StdEncoding.DecodeString(testCase.customerKey))) { + t.Fatal("unexpected value for customer encryption key") + } + + createS3Bucket(t, b.s3Client, bucketName) + defer deleteS3Bucket(t, b.s3Client, bucketName) + + backend.TestBackendStates(t, b) + } + }) + } +} + +func TestBackendSSECustomerKeyEnvVar(t *testing.T) { + testACC(t) + + testCases := map[string]struct { + customerKey string + expectedErr string + }{ + "invalid length": { + customerKey: "test", + expectedErr: `AWS_SSE_CUSTOMER_KEY must be 44 characters in length`, + }, + "invalid encoding": { + customerKey: "====CT70aTYB2JGff7AjQtwbiLkwH4npICay1PWtmdka", + expectedErr: `AWS_SSE_CUSTOMER_KEY must be base64 encoded`, + }, + "valid": { + customerKey: "4Dm1n4rphuFgawxuzY/bEfvLf6rYK0gIjfaDSLlfXNk=", + }, + } + + for name, testCase := range testCases { + testCase := testCase + + t.Run(name, func(t *testing.T) { + bucketName := fmt.Sprintf("terraform-remote-s3-test-%x", time.Now().Unix()) + config := map[string]interface{}{ + "bucket": bucketName, + "encrypt": true, + "key": "test-SSE-C", + } + + os.Setenv("AWS_SSE_CUSTOMER_KEY", testCase.customerKey) + t.Cleanup(func() { + os.Unsetenv("AWS_SSE_CUSTOMER_KEY") + }) + + b := New().(*Backend) + diags := b.Configure(populateSchema(t, b.ConfigSchema(), hcl2shim.HCL2ValueFromConfigValue(config))) + + if testCase.expectedErr != "" { + if diags.Err() != nil { + actualErr := diags.Err().Error() + if !strings.Contains(actualErr, testCase.expectedErr) { + t.Fatalf("unexpected validation result: %v", diags.Err()) + } + } else { + t.Fatal("expected an error, got none") + } + } else { + if diags.Err() != nil { + t.Fatalf("expected no error, got %s", diags.Err()) + } + if string(b.customerEncryptionKey) != string(must(base64.StdEncoding.DecodeString(testCase.customerKey))) { + t.Fatal("unexpected value for customer encryption key") + } + + createS3Bucket(t, b.s3Client, bucketName) + defer deleteS3Bucket(t, b.s3Client, bucketName) + + backend.TestBackendStates(t, b) + } + }) + } } // add some extra junk in S3 to try and confuse the env listing. @@ -1205,3 +1297,11 @@ func popEnv(env []string) { os.Setenv(k, v) } } + +func must[T any](v T, err error) T { + if err != nil { + panic(err) + } else { + return v + } +}