Combines sse_customer_key and AWS_SSE_CUSTOMER_KEY validation

This commit is contained in:
Graham Davison 2022-10-27 14:39:48 -07:00
parent 4eaa44c5a5
commit 827d7bd384
2 changed files with 168 additions and 55 deletions

View File

@ -258,29 +258,6 @@ func (b *Backend) PrepareConfig(obj cty.Value) (cty.Value, tfdiags.Diagnostics)
}
}
if val := obj.GetAttr("sse_customer_key"); !val.IsNull() {
s := val.AsString()
if len(s) != 44 {
diags = diags.Append(tfdiags.AttributeValue(
tfdiags.Error,
"Invalid sse_customer_key value",
"sse_customer_key must be 44 characters in length",
cty.Path{cty.GetAttrStep{Name: "sse_customer_key"}},
))
} else {
var err error
_, err = base64.StdEncoding.DecodeString(s)
if err != nil {
diags = diags.Append(tfdiags.AttributeValue(
tfdiags.Error,
"Invalid sse_customer_key value",
fmt.Sprintf("sse_customer_key must be base64 encoded: %s", err),
cty.Path{cty.GetAttrStep{Name: "sse_customer_key"}},
))
}
}
}
if val := obj.GetAttr("kms_key_id"); !val.IsNull() && val.AsString() != "" {
if val := obj.GetAttr("sse_customer_key"); !val.IsNull() && val.AsString() != "" {
diags = diags.Append(tfdiags.AttributeValue(
@ -337,9 +314,45 @@ func (b *Backend) Configure(obj cty.Value) tfdiags.Diagnostics {
b.kmsKeyID = stringAttr(obj, "kms_key_id")
b.ddbTable = stringAttr(obj, "dynamodb_table")
if customerKeyString, ok := stringAttrOk(obj, "sse_customer_key"); ok {
// Validation is handled in PrepareConfig, so ignore it here
b.customerEncryptionKey, _ = base64.StdEncoding.DecodeString(customerKeyString)
// WarnOnEmptyString(), LenEquals(44), IsBase64Encoded()
if customerKey, ok := stringAttrOk(obj, "sse_customer_key"); ok {
if len(customerKey) != 44 {
diags = diags.Append(tfdiags.AttributeValue(
tfdiags.Error,
"Invalid sse_customer_key value",
"sse_customer_key must be 44 characters in length",
cty.Path{cty.GetAttrStep{Name: "sse_customer_key"}},
))
} else {
var err error
if b.customerEncryptionKey, err = base64.StdEncoding.DecodeString(customerKey); err != nil {
diags = diags.Append(tfdiags.AttributeValue(
tfdiags.Error,
"Invalid sse_customer_key value",
fmt.Sprintf("sse_customer_key must be base64 encoded: %s", err),
cty.Path{cty.GetAttrStep{Name: "sse_customer_key"}},
))
}
}
} else {
if customerKey := os.Getenv("AWS_SSE_CUSTOMER_KEY"); customerKey != "" {
if len(customerKey) != 44 {
diags = diags.Append(tfdiags.Sourceless(
tfdiags.Error,
"Invalid AWS_SSE_CUSTOMER_KEY value",
"AWS_SSE_CUSTOMER_KEY must be 44 characters in length",
))
} else {
var err error
if b.customerEncryptionKey, err = base64.StdEncoding.DecodeString(customerKey); err != nil {
diags = diags.Append(tfdiags.Sourceless(
tfdiags.Error,
"Invalid AWS_SSE_CUSTOMER_KEY value",
fmt.Sprintf("AWS_SSE_CUSTOMER_KEY must be base64 encoded: %s", err),
))
}
}
}
}
cfg := &awsbase.Config{

View File

@ -4,6 +4,7 @@
package s3
import (
"encoding/base64"
"fmt"
"net/url"
"os"
@ -578,24 +579,6 @@ func TestBackendConfig_PrepareConfigValidation(t *testing.T) {
}),
expectedErr: `workspace_key_prefix must not start or end with '/'`,
},
"sse_customer_key invalid length": {
config: cty.ObjectVal(map[string]cty.Value{
"bucket": cty.StringVal("test"),
"key": cty.StringVal("test"),
"region": cty.StringVal("us-west-2"),
"sse_customer_key": cty.StringVal("key"),
}),
expectedErr: `sse_customer_key must be 44 characters in length`,
},
"sse_customer_key invalid encoding": {
config: cty.ObjectVal(map[string]cty.Value{
"bucket": cty.StringVal("test"),
"key": cty.StringVal("test"),
"region": cty.StringVal("us-west-2"),
"sse_customer_key": cty.StringVal("====CT70aTYB2JGff7AjQtwbiLkwH4npICay1PWtmdka"),
}),
expectedErr: `sse_customer_key must be base64 encoded`,
},
"encyrption key conflict": {
config: cty.ObjectVal(map[string]cty.Value{
"bucket": cty.StringVal("test"),
@ -736,21 +719,130 @@ func TestBackendLocked(t *testing.T) {
backend.TestBackendStateForceUnlock(t, b1, b2)
}
func TestBackendSSECustomerKey(t *testing.T) {
func TestBackendSSECustomerKeyConfig(t *testing.T) {
testACC(t)
bucketName := fmt.Sprintf("terraform-remote-s3-test-%x", time.Now().Unix())
b := backend.TestBackendConfig(t, New(), backend.TestWrapConfig(map[string]interface{}{
"bucket": bucketName,
"encrypt": true,
"key": "test-SSE-C",
"sse_customer_key": "4Dm1n4rphuFgawxuzY/bEfvLf6rYK0gIjfaDSLlfXNk=",
})).(*Backend)
testCases := map[string]struct {
customerKey string
expectedErr string
}{
"invalid length": {
customerKey: "test",
expectedErr: `sse_customer_key must be 44 characters in length`,
},
"invalid encoding": {
customerKey: "====CT70aTYB2JGff7AjQtwbiLkwH4npICay1PWtmdka",
expectedErr: `sse_customer_key must be base64 encoded`,
},
"valid": {
customerKey: "4Dm1n4rphuFgawxuzY/bEfvLf6rYK0gIjfaDSLlfXNk=",
},
}
createS3Bucket(t, b.s3Client, bucketName)
defer deleteS3Bucket(t, b.s3Client, bucketName)
for name, testCase := range testCases {
testCase := testCase
backend.TestBackendStates(t, b)
t.Run(name, func(t *testing.T) {
bucketName := fmt.Sprintf("terraform-remote-s3-test-%x", time.Now().Unix())
config := map[string]interface{}{
"bucket": bucketName,
"encrypt": true,
"key": "test-SSE-C",
"sse_customer_key": testCase.customerKey,
}
b := New().(*Backend)
diags := b.Configure(populateSchema(t, b.ConfigSchema(), hcl2shim.HCL2ValueFromConfigValue(config)))
if testCase.expectedErr != "" {
if diags.Err() != nil {
actualErr := diags.Err().Error()
if !strings.Contains(actualErr, testCase.expectedErr) {
t.Fatalf("unexpected validation result: %v", diags.Err())
}
} else {
t.Fatal("expected an error, got none")
}
} else {
if diags.Err() != nil {
t.Fatalf("expected no error, got %s", diags.Err())
}
if string(b.customerEncryptionKey) != string(must(base64.StdEncoding.DecodeString(testCase.customerKey))) {
t.Fatal("unexpected value for customer encryption key")
}
createS3Bucket(t, b.s3Client, bucketName)
defer deleteS3Bucket(t, b.s3Client, bucketName)
backend.TestBackendStates(t, b)
}
})
}
}
func TestBackendSSECustomerKeyEnvVar(t *testing.T) {
testACC(t)
testCases := map[string]struct {
customerKey string
expectedErr string
}{
"invalid length": {
customerKey: "test",
expectedErr: `AWS_SSE_CUSTOMER_KEY must be 44 characters in length`,
},
"invalid encoding": {
customerKey: "====CT70aTYB2JGff7AjQtwbiLkwH4npICay1PWtmdka",
expectedErr: `AWS_SSE_CUSTOMER_KEY must be base64 encoded`,
},
"valid": {
customerKey: "4Dm1n4rphuFgawxuzY/bEfvLf6rYK0gIjfaDSLlfXNk=",
},
}
for name, testCase := range testCases {
testCase := testCase
t.Run(name, func(t *testing.T) {
bucketName := fmt.Sprintf("terraform-remote-s3-test-%x", time.Now().Unix())
config := map[string]interface{}{
"bucket": bucketName,
"encrypt": true,
"key": "test-SSE-C",
}
os.Setenv("AWS_SSE_CUSTOMER_KEY", testCase.customerKey)
t.Cleanup(func() {
os.Unsetenv("AWS_SSE_CUSTOMER_KEY")
})
b := New().(*Backend)
diags := b.Configure(populateSchema(t, b.ConfigSchema(), hcl2shim.HCL2ValueFromConfigValue(config)))
if testCase.expectedErr != "" {
if diags.Err() != nil {
actualErr := diags.Err().Error()
if !strings.Contains(actualErr, testCase.expectedErr) {
t.Fatalf("unexpected validation result: %v", diags.Err())
}
} else {
t.Fatal("expected an error, got none")
}
} else {
if diags.Err() != nil {
t.Fatalf("expected no error, got %s", diags.Err())
}
if string(b.customerEncryptionKey) != string(must(base64.StdEncoding.DecodeString(testCase.customerKey))) {
t.Fatal("unexpected value for customer encryption key")
}
createS3Bucket(t, b.s3Client, bucketName)
defer deleteS3Bucket(t, b.s3Client, bucketName)
backend.TestBackendStates(t, b)
}
})
}
}
// add some extra junk in S3 to try and confuse the env listing.
@ -1205,3 +1297,11 @@ func popEnv(env []string) {
os.Setenv(k, v)
}
}
func must[T any](v T, err error) T {
if err != nil {
panic(err)
} else {
return v
}
}