From 73f5fbf4bc3164a60c95c03b4df80a7e96df0307 Mon Sep 17 00:00:00 2001 From: James Humphries Date: Mon, 18 Mar 2024 18:48:19 +0000 Subject: [PATCH] Added aws_kms key provider compliance tests (#1395) Signed-off-by: James Humphries Signed-off-by: Christian Mesh Co-authored-by: Christian Mesh --- .../keyprovider/aws_kms/compliance_test.go | 142 ++++++++++++++++++ .../encryption/keyprovider/aws_kms/config.go | 16 +- .../keyprovider/aws_kms/config_test.go | 41 +++++ .../keyprovider/aws_kms/mock_test.go | 50 ++++++ .../keyprovider/aws_kms/provider.go | 16 +- .../keyprovider/aws_kms/provider_test.go | 133 ++-------------- .../keyprovider/compliancetest/compliance.go | 4 +- 7 files changed, 274 insertions(+), 128 deletions(-) create mode 100644 internal/encryption/keyprovider/aws_kms/compliance_test.go create mode 100644 internal/encryption/keyprovider/aws_kms/mock_test.go diff --git a/internal/encryption/keyprovider/aws_kms/compliance_test.go b/internal/encryption/keyprovider/aws_kms/compliance_test.go new file mode 100644 index 0000000000..a281ab5e0b --- /dev/null +++ b/internal/encryption/keyprovider/aws_kms/compliance_test.go @@ -0,0 +1,142 @@ +// Copyright (c) The OpenTofu Authors +// SPDX-License-Identifier: MPL-2.0 +// Copyright (c) 2023 HashiCorp, Inc. +// SPDX-License-Identifier: MPL-2.0 + +package aws_kms + +import ( + "fmt" + "testing" + + "github.com/opentofu/opentofu/internal/encryption/keyprovider/compliancetest" +) + +func TestKeyProvider(t *testing.T) { + testKeyId := getKey(t) + + if testKeyId == "" { + testKeyId = "alias/my-mock-key" + injectDefaultMock() + + t.Setenv("AWS_REGION", "us-east-1") + t.Setenv("AWS_ACCESS_KEY_ID", "accesskey") + t.Setenv("AWS_SECRET_ACCESS_KEY", "secretkey") + } + + compliancetest.ComplianceTest( + t, + compliancetest.TestConfiguration[*descriptor, *Config, *keyMeta, *keyProvider]{ + Descriptor: New().(*descriptor), + HCLParseTestCases: map[string]compliancetest.HCLParseTestCase[*Config, *keyProvider]{ + "success": { + HCL: fmt.Sprintf(`key_provider "aws_kms" "foo" { + kms_key_id = "%s" + key_spec = "AES_256" + skip_credentials_validation = true // required for mocking + }`, testKeyId), + ValidHCL: true, + ValidBuild: true, + Validate: func(config *Config, keyProvider *keyProvider) error { + if config.KMSKeyID != testKeyId { + return fmt.Errorf("incorrect key ID returned") + } + return nil + }, + }, + "empty": { + HCL: `key_provider "aws_kms" "foo" {}`, + ValidHCL: false, + ValidBuild: false, + }, + "invalid-key-spec": { + HCL: fmt.Sprintf(`key_provider "aws_kms" "foo" { + kms_key_id = "%s" + key_spec = "BROKEN STUFF" + }`, testKeyId), + ValidHCL: true, + ValidBuild: false, + }, + "empty-key-id": { + HCL: `key_provider "aws_kms" "foo" { + kms_key_id = "" + key_spec = "AES_256" + }`, + ValidHCL: true, + ValidBuild: false, + }, + "empty-key-spec": { + HCL: `key_provider "aws_kms" "foo" { + kms_key_id = "alias/temp" + key_spec = "" + }`, + ValidHCL: true, + ValidBuild: false, + }, + "unknown-property": { + HCL: fmt.Sprintf(`key_provider "aws_kms" "foo" { + kms_key_id = "%s" + key_spec = "AES_256" + unknown_property = "foo" + }`, testKeyId), + ValidHCL: false, + ValidBuild: false, + }, + }, + ConfigStructTestCases: map[string]compliancetest.ConfigStructTestCase[*Config, *keyProvider]{ + "success": { + Config: &Config{ + KMSKeyID: testKeyId, + KeySpec: "AES_256", + + SkipCredsValidation: true, // Required for mocking + }, + ValidBuild: true, + Validate: nil, + }, + "empty": { + Config: &Config{ + KMSKeyID: "", + KeySpec: "", + }, + ValidBuild: false, + Validate: nil, + }, + }, + MetadataStructTestCases: map[string]compliancetest.MetadataStructTestCase[*Config, *keyMeta]{ + "empty": { + ValidConfig: &Config{ + KMSKeyID: testKeyId, + KeySpec: "AES_256", + + SkipCredsValidation: true, // Required for mocking + }, + Meta: &keyMeta{}, + IsPresent: false, + IsValid: false, + }, + }, + ProvideTestCase: compliancetest.ProvideTestCase[*Config, *keyMeta]{ + ValidConfig: &Config{ + KMSKeyID: testKeyId, + KeySpec: "AES_256", + SkipCredsValidation: true, // Required for mocking + }, + ValidateKeys: func(dec []byte, enc []byte) error { + if len(dec) == 0 { + return fmt.Errorf("decryption key is empty") + } + if len(enc) == 0 { + return fmt.Errorf("encryption key is empty") + } + return nil + }, + ValidateMetadata: func(meta *keyMeta) error { + if meta.CiphertextBlob == nil || len(meta.CiphertextBlob) == 0 { + return fmt.Errorf("ciphertext blob is nil") + } + return nil + }, + }, + }) +} diff --git a/internal/encryption/keyprovider/aws_kms/config.go b/internal/encryption/keyprovider/aws_kms/config.go index 11ef68565d..33d79b72bd 100644 --- a/internal/encryption/keyprovider/aws_kms/config.go +++ b/internal/encryption/keyprovider/aws_kms/config.go @@ -17,6 +17,11 @@ import ( "github.com/opentofu/opentofu/version" ) +// Can be overridden for test mocking +var newKMSFromConfig func(aws.Config) kmsClient = func(cfg aws.Config) kmsClient { + return kms.NewFromConfig(cfg) +} + type Config struct { // KeyProvider Config KMSKeyID string `hcl:"kms_key_id"` @@ -193,7 +198,7 @@ func (c Config) Build() (keyprovider.KeyProvider, keyprovider.KeyMeta, error) { return &keyProvider{ Config: c, - svc: kms.NewFromConfig(awsConfig), + svc: newKMSFromConfig(awsConfig), ctx: ctx, }, new(keyMeta), nil } @@ -214,8 +219,11 @@ func (c Config) validate() (err error) { spec := c.getKeySpecAsAWSType() if spec == nil { + // This is to fetch a list of the values from the enum, because `spec` here can be nil, so we have to grab + // at least one of the enum possibilities here just to call .Values() + values := types.DataKeySpecAes256.Values() return &keyprovider.ErrInvalidConfiguration{ - Message: fmt.Sprintf("invalid key_spec %s, expected one of %v", c.KeySpec, spec.Values()), + Message: fmt.Sprintf("invalid key_spec %s, expected one of %v", c.KeySpec, values), } } @@ -228,10 +236,10 @@ func (c Config) getKeySpecAsAWSType() *types.DataKeySpec { var spec types.DataKeySpec for _, opt := range spec.Values() { if string(opt) == c.KeySpec { - spec = opt + return &opt } } - return &spec + return nil } // Mirrored from s3 backend config diff --git a/internal/encryption/keyprovider/aws_kms/config_test.go b/internal/encryption/keyprovider/aws_kms/config_test.go index 6d8be88261..b01c9b884a 100644 --- a/internal/encryption/keyprovider/aws_kms/config_test.go +++ b/internal/encryption/keyprovider/aws_kms/config_test.go @@ -8,6 +8,7 @@ import ( "github.com/aws/aws-sdk-go-v2/aws" "github.com/aws/aws-sdk-go-v2/feature/ec2/imds" + "github.com/aws/aws-sdk-go-v2/service/kms/types" "github.com/davecgh/go-spew/spew" awsbase "github.com/hashicorp/aws-sdk-go-base/v2" "github.com/hashicorp/hcl/v2" @@ -243,3 +244,43 @@ func TestValidate(t *testing.T) { }) } } + +func TestGetKeySpecAsAWSType(t *testing.T) { + + aes256 := types.DataKeySpecAes256 + aes128 := types.DataKeySpecAes128 + + cases := []struct { + key string + expected *types.DataKeySpec + }{ + { + key: "AES_256", + expected: &aes256, + }, + { + key: "AES_128", + expected: &aes128, + }, + { + key: "", + expected: nil, + }, + { + key: "invalidKey", + expected: nil, + }, + } + + for _, c := range cases { + t.Run(c.key, func(t *testing.T) { + config := Config{ + KeySpec: c.key, + } + actual := config.getKeySpecAsAWSType() + if !reflect.DeepEqual(c.expected, actual) { + t.Fatalf("Expected %s, got %s", spew.Sdump(c.expected), spew.Sdump(actual)) + } + }) + } +} diff --git a/internal/encryption/keyprovider/aws_kms/mock_test.go b/internal/encryption/keyprovider/aws_kms/mock_test.go new file mode 100644 index 0000000000..681fa13e0d --- /dev/null +++ b/internal/encryption/keyprovider/aws_kms/mock_test.go @@ -0,0 +1,50 @@ +package aws_kms + +import ( + "context" + "crypto/rand" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/kms" +) + +type mockKMS struct { + genkey func(params *kms.GenerateDataKeyInput) (*kms.GenerateDataKeyOutput, error) + decrypt func(params *kms.DecryptInput) (*kms.DecryptOutput, error) +} + +func (m *mockKMS) GenerateDataKey(ctx context.Context, params *kms.GenerateDataKeyInput, optFns ...func(*kms.Options)) (*kms.GenerateDataKeyOutput, error) { + return m.genkey(params) +} +func (m *mockKMS) Decrypt(ctx context.Context, params *kms.DecryptInput, optFns ...func(*kms.Options)) (*kms.DecryptOutput, error) { + return m.decrypt(params) +} + +func injectMock(m *mockKMS) { + newKMSFromConfig = func(cfg aws.Config) kmsClient { + return m + } +} + +func injectDefaultMock() { + injectMock(&mockKMS{ + genkey: func(params *kms.GenerateDataKeyInput) (*kms.GenerateDataKeyOutput, error) { + keyData := make([]byte, 32) + _, err := rand.Read(keyData) + if err != nil { + panic(err) + } + + return &kms.GenerateDataKeyOutput{ + CiphertextBlob: append([]byte(*params.KeyId), keyData...), + Plaintext: keyData, + }, nil + + }, + decrypt: func(params *kms.DecryptInput) (*kms.DecryptOutput, error) { + return &kms.DecryptOutput{ + Plaintext: params.CiphertextBlob[:len(*params.KeyId)], + }, nil + }, + }) +} diff --git a/internal/encryption/keyprovider/aws_kms/provider.go b/internal/encryption/keyprovider/aws_kms/provider.go index 5967273aa4..2cb57cc2aa 100644 --- a/internal/encryption/keyprovider/aws_kms/provider.go +++ b/internal/encryption/keyprovider/aws_kms/provider.go @@ -17,17 +17,25 @@ func (m keyMeta) isPresent() bool { return len(m.CiphertextBlob) != 0 } +type kmsClient interface { + GenerateDataKey(ctx context.Context, params *kms.GenerateDataKeyInput, optFns ...func(*kms.Options)) (*kms.GenerateDataKeyOutput, error) + Decrypt(ctx context.Context, params *kms.DecryptInput, optFns ...func(*kms.Options)) (*kms.DecryptOutput, error) +} + type keyProvider struct { Config - svc *kms.Client + svc kmsClient ctx context.Context } func (p keyProvider) Provide(rawMeta keyprovider.KeyMeta) (keyprovider.Output, keyprovider.KeyMeta, error) { if rawMeta == nil { - return keyprovider.Output{}, nil, keyprovider.ErrInvalidMetadata{Message: "bug: no metadata struct provided"} + return keyprovider.Output{}, nil, &keyprovider.ErrInvalidMetadata{Message: "bug: no metadata struct provided"} + } + inMeta, ok := rawMeta.(*keyMeta) + if !ok { + return keyprovider.Output{}, nil, &keyprovider.ErrInvalidMetadata{Message: "bug: metadata struct is not of the correct type"} } - inMeta := rawMeta.(*keyMeta) outMeta := &keyMeta{} out := keyprovider.Output{} @@ -62,7 +70,7 @@ func (p keyProvider) Provide(rawMeta keyprovider.KeyMeta) (keyprovider.Output, k }) if decryptErr != nil { - return out, outMeta, decryptErr + return out, outMeta, &keyprovider.ErrKeyProviderFailure{Cause: decryptErr} } // Set decryption key on the output diff --git a/internal/encryption/keyprovider/aws_kms/provider_test.go b/internal/encryption/keyprovider/aws_kms/provider_test.go index 2fcc923462..28f6ccb275 100644 --- a/internal/encryption/keyprovider/aws_kms/provider_test.go +++ b/internal/encryption/keyprovider/aws_kms/provider_test.go @@ -1,65 +1,36 @@ package aws_kms import ( - "context" - "fmt" "os" "testing" - "time" - - "github.com/aws/aws-sdk-go-v2/aws" - "github.com/aws/aws-sdk-go-v2/service/kms" - "github.com/aws/aws-sdk-go-v2/service/kms/types" - awsbase "github.com/hashicorp/aws-sdk-go-base/v2" ) -// skipCheck checks if the test should be skipped or not based on environment variables -func skipCheck(t *testing.T) { - // check if TF_ACC and TF_KMS_TEST are unset - // if so, skip the test +func getKey(t *testing.T) string { if os.Getenv("TF_ACC") == "" && os.Getenv("TF_KMS_TEST") == "" { - t.Log("Skipping test because TF_ACC or TF_KMS_TEST is not set") - t.Skip() + return "" } + return os.Getenv("TF_AWS_KMS_KEY_ID") } -const testKeyPrefix = "tf-acc-test-kms-key" -const testAliasPrefix = "alias/my-key-alias" - func TestKMSProvider_Simple(t *testing.T) { - skipCheck(t) - ctx := context.TODO() + testKeyId := getKey(t) + if testKeyId == "" { + testKeyId = "alias/my-mock-key" + injectDefaultMock() - keyName := fmt.Sprintf("%s-%x", testKeyPrefix, time.Now().Unix()) - alias := fmt.Sprintf("%s-%x", testAliasPrefix, time.Now().Unix()) + t.Setenv("AWS_REGION", "us-east-1") + t.Setenv("AWS_ACCESS_KEY_ID", "accesskey") + t.Setenv("AWS_SECRET_ACCESS_KEY", "secretkey") + } - // Constructs a aws kms key provider config that accepts the alias as the key id + // Constructs a aws kms key provider config that accepts the key id providerConfig := Config{ - KMSKeyID: alias, + KMSKeyID: testKeyId, KeySpec: "AES_256", + + SkipCredsValidation: true, // Required for mocking } - // Mimic the creation of the aws client here via providerConfig.asAWSBase() so that - // we create a key in the same way that it will be read - awsBaseConfig, err := providerConfig.asAWSBase() - if err != nil { - t.Fatalf("Error creating AWS config: %s", err) - } - _, awsConfig, awsDiags := awsbase.GetAwsConfig(ctx, awsBaseConfig) - if awsDiags.HasError() { - t.Fatalf("Error creating AWS config: %v", awsDiags) - } - - kmsClient := kms.NewFromConfig(awsConfig) - - // Create the key - keyId := createKMSKey(ctx, t, kmsClient, keyName, awsBaseConfig.Region) - defer scheduleKMSKeyDeletion(ctx, t, kms.NewFromConfig(awsConfig), keyId) - - // Create an alias for the key - createAlias(ctx, t, kmsClient, keyId, &alias) - defer deleteAlias(ctx, t, kms.NewFromConfig(awsConfig), &alias) - // Now that we have the config, we can build the provider provider, metaIn, err := providerConfig.Build() if err != nil { @@ -104,77 +75,3 @@ func TestKMSProvider_Simple(t *testing.T) { t.Fatalf("No ciphertext blob provided") } } - -// createKMSKey creates a KMS key with the given name and region -func createKMSKey(ctx context.Context, t *testing.T, kmsClient *kms.Client, keyName string, region string) (keyID string) { - createKeyReq := kms.CreateKeyInput{ - Tags: []types.Tag{ - { - TagKey: aws.String("Name"), - TagValue: aws.String(keyName), - }, - }, - } - - t.Logf("Creating KMS key %s in %s", keyName, region) - - created, err := kmsClient.CreateKey(ctx, &createKeyReq) - if err != nil { - t.Fatalf("Error creating KMS key: %s", err) - } - - return *created.KeyMetadata.KeyId -} - -// createAlias creates a KMS alias for the given key -func createAlias(ctx context.Context, t *testing.T, kmsClient *kms.Client, keyID string, alias *string) { - if alias == nil { - return - } - - t.Logf("Creating KMS alias %s for key %s", *alias, keyID) - - aliasReq := kms.CreateAliasInput{ - AliasName: aws.String(*alias), - TargetKeyId: aws.String(keyID), - } - - _, err := kmsClient.CreateAlias(ctx, &aliasReq) - if err != nil { - t.Fatalf("Error creating KMS alias: %s", err) - } -} - -// scheduleKMSKeyDeletion schedules the deletion of a KMS key -// this attempts to delete it in the fastest possible way (7 days) -func scheduleKMSKeyDeletion(ctx context.Context, t *testing.T, kmsClient *kms.Client, keyID string) { - deleteKeyReq := kms.ScheduleKeyDeletionInput{ - KeyId: aws.String(keyID), - PendingWindowInDays: aws.Int32(7), - } - - t.Logf("Scheduling KMS key %s for deletion", keyID) - - _, err := kmsClient.ScheduleKeyDeletion(ctx, &deleteKeyReq) - if err != nil { - t.Fatalf("Error deleting KMS key: %s", err) - } -} - -// deleteAlias deletes a KMS alias -func deleteAlias(ctx context.Context, t *testing.T, kmsClient *kms.Client, alias *string) { - if alias == nil { - return - } - - t.Logf("Deleting KMS alias %s", *alias) - - deleteAliasReq := kms.DeleteAliasInput{ - AliasName: aws.String(*alias), - } - - _, err := kmsClient.DeleteAlias(ctx, &deleteAliasReq) - if err != nil { - t.Fatalf("Error deleting KMS alias: %s", err) - } -} diff --git a/internal/encryption/keyprovider/compliancetest/compliance.go b/internal/encryption/keyprovider/compliancetest/compliance.go index af1738bc09..ec2baed200 100644 --- a/internal/encryption/keyprovider/compliancetest/compliance.go +++ b/internal/encryption/keyprovider/compliancetest/compliance.go @@ -131,10 +131,10 @@ func ComplianceTest[TDescriptor keyprovider.Descriptor, TConfig keyprovider.Conf compliancetest.Fail(t, "Please provide at least one metadata test that represents non-present metadata.") } if !hasNotValid { - compliancetest.Fail(t, "Please provide at least one metadata test that represents an invalid metadata that is present.") + compliancetest.Log(t, "Warning: Please provide at least one metadata test that represents an invalid metadata that is present.") } if !hasValid { - compliancetest.Fail(t, "Please provide at least one metadata test that represents a valid metadata.") + compliancetest.Log(t, "Warning: Please provide at least one metadata test that represents a valid metadata.") } }) })