mirror of
https://github.com/opentofu/opentofu.git
synced 2025-02-25 18:45:20 -06:00
Added aws_kms key provider compliance tests (#1395)
Signed-off-by: James Humphries <james@james-humphries.co.uk> Signed-off-by: Christian Mesh <christianmesh1@gmail.com> Co-authored-by: Christian Mesh <christianmesh1@gmail.com>
This commit is contained in:
parent
21aa528090
commit
73f5fbf4bc
142
internal/encryption/keyprovider/aws_kms/compliance_test.go
Normal file
142
internal/encryption/keyprovider/aws_kms/compliance_test.go
Normal file
@ -0,0 +1,142 @@
|
||||
// Copyright (c) The OpenTofu Authors
|
||||
// SPDX-License-Identifier: MPL-2.0
|
||||
// Copyright (c) 2023 HashiCorp, Inc.
|
||||
// SPDX-License-Identifier: MPL-2.0
|
||||
|
||||
package aws_kms
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/opentofu/opentofu/internal/encryption/keyprovider/compliancetest"
|
||||
)
|
||||
|
||||
func TestKeyProvider(t *testing.T) {
|
||||
testKeyId := getKey(t)
|
||||
|
||||
if testKeyId == "" {
|
||||
testKeyId = "alias/my-mock-key"
|
||||
injectDefaultMock()
|
||||
|
||||
t.Setenv("AWS_REGION", "us-east-1")
|
||||
t.Setenv("AWS_ACCESS_KEY_ID", "accesskey")
|
||||
t.Setenv("AWS_SECRET_ACCESS_KEY", "secretkey")
|
||||
}
|
||||
|
||||
compliancetest.ComplianceTest(
|
||||
t,
|
||||
compliancetest.TestConfiguration[*descriptor, *Config, *keyMeta, *keyProvider]{
|
||||
Descriptor: New().(*descriptor),
|
||||
HCLParseTestCases: map[string]compliancetest.HCLParseTestCase[*Config, *keyProvider]{
|
||||
"success": {
|
||||
HCL: fmt.Sprintf(`key_provider "aws_kms" "foo" {
|
||||
kms_key_id = "%s"
|
||||
key_spec = "AES_256"
|
||||
skip_credentials_validation = true // required for mocking
|
||||
}`, testKeyId),
|
||||
ValidHCL: true,
|
||||
ValidBuild: true,
|
||||
Validate: func(config *Config, keyProvider *keyProvider) error {
|
||||
if config.KMSKeyID != testKeyId {
|
||||
return fmt.Errorf("incorrect key ID returned")
|
||||
}
|
||||
return nil
|
||||
},
|
||||
},
|
||||
"empty": {
|
||||
HCL: `key_provider "aws_kms" "foo" {}`,
|
||||
ValidHCL: false,
|
||||
ValidBuild: false,
|
||||
},
|
||||
"invalid-key-spec": {
|
||||
HCL: fmt.Sprintf(`key_provider "aws_kms" "foo" {
|
||||
kms_key_id = "%s"
|
||||
key_spec = "BROKEN STUFF"
|
||||
}`, testKeyId),
|
||||
ValidHCL: true,
|
||||
ValidBuild: false,
|
||||
},
|
||||
"empty-key-id": {
|
||||
HCL: `key_provider "aws_kms" "foo" {
|
||||
kms_key_id = ""
|
||||
key_spec = "AES_256"
|
||||
}`,
|
||||
ValidHCL: true,
|
||||
ValidBuild: false,
|
||||
},
|
||||
"empty-key-spec": {
|
||||
HCL: `key_provider "aws_kms" "foo" {
|
||||
kms_key_id = "alias/temp"
|
||||
key_spec = ""
|
||||
}`,
|
||||
ValidHCL: true,
|
||||
ValidBuild: false,
|
||||
},
|
||||
"unknown-property": {
|
||||
HCL: fmt.Sprintf(`key_provider "aws_kms" "foo" {
|
||||
kms_key_id = "%s"
|
||||
key_spec = "AES_256"
|
||||
unknown_property = "foo"
|
||||
}`, testKeyId),
|
||||
ValidHCL: false,
|
||||
ValidBuild: false,
|
||||
},
|
||||
},
|
||||
ConfigStructTestCases: map[string]compliancetest.ConfigStructTestCase[*Config, *keyProvider]{
|
||||
"success": {
|
||||
Config: &Config{
|
||||
KMSKeyID: testKeyId,
|
||||
KeySpec: "AES_256",
|
||||
|
||||
SkipCredsValidation: true, // Required for mocking
|
||||
},
|
||||
ValidBuild: true,
|
||||
Validate: nil,
|
||||
},
|
||||
"empty": {
|
||||
Config: &Config{
|
||||
KMSKeyID: "",
|
||||
KeySpec: "",
|
||||
},
|
||||
ValidBuild: false,
|
||||
Validate: nil,
|
||||
},
|
||||
},
|
||||
MetadataStructTestCases: map[string]compliancetest.MetadataStructTestCase[*Config, *keyMeta]{
|
||||
"empty": {
|
||||
ValidConfig: &Config{
|
||||
KMSKeyID: testKeyId,
|
||||
KeySpec: "AES_256",
|
||||
|
||||
SkipCredsValidation: true, // Required for mocking
|
||||
},
|
||||
Meta: &keyMeta{},
|
||||
IsPresent: false,
|
||||
IsValid: false,
|
||||
},
|
||||
},
|
||||
ProvideTestCase: compliancetest.ProvideTestCase[*Config, *keyMeta]{
|
||||
ValidConfig: &Config{
|
||||
KMSKeyID: testKeyId,
|
||||
KeySpec: "AES_256",
|
||||
SkipCredsValidation: true, // Required for mocking
|
||||
},
|
||||
ValidateKeys: func(dec []byte, enc []byte) error {
|
||||
if len(dec) == 0 {
|
||||
return fmt.Errorf("decryption key is empty")
|
||||
}
|
||||
if len(enc) == 0 {
|
||||
return fmt.Errorf("encryption key is empty")
|
||||
}
|
||||
return nil
|
||||
},
|
||||
ValidateMetadata: func(meta *keyMeta) error {
|
||||
if meta.CiphertextBlob == nil || len(meta.CiphertextBlob) == 0 {
|
||||
return fmt.Errorf("ciphertext blob is nil")
|
||||
}
|
||||
return nil
|
||||
},
|
||||
},
|
||||
})
|
||||
}
|
@ -17,6 +17,11 @@ import (
|
||||
"github.com/opentofu/opentofu/version"
|
||||
)
|
||||
|
||||
// Can be overridden for test mocking
|
||||
var newKMSFromConfig func(aws.Config) kmsClient = func(cfg aws.Config) kmsClient {
|
||||
return kms.NewFromConfig(cfg)
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
// KeyProvider Config
|
||||
KMSKeyID string `hcl:"kms_key_id"`
|
||||
@ -193,7 +198,7 @@ func (c Config) Build() (keyprovider.KeyProvider, keyprovider.KeyMeta, error) {
|
||||
|
||||
return &keyProvider{
|
||||
Config: c,
|
||||
svc: kms.NewFromConfig(awsConfig),
|
||||
svc: newKMSFromConfig(awsConfig),
|
||||
ctx: ctx,
|
||||
}, new(keyMeta), nil
|
||||
}
|
||||
@ -214,8 +219,11 @@ func (c Config) validate() (err error) {
|
||||
|
||||
spec := c.getKeySpecAsAWSType()
|
||||
if spec == nil {
|
||||
// This is to fetch a list of the values from the enum, because `spec` here can be nil, so we have to grab
|
||||
// at least one of the enum possibilities here just to call .Values()
|
||||
values := types.DataKeySpecAes256.Values()
|
||||
return &keyprovider.ErrInvalidConfiguration{
|
||||
Message: fmt.Sprintf("invalid key_spec %s, expected one of %v", c.KeySpec, spec.Values()),
|
||||
Message: fmt.Sprintf("invalid key_spec %s, expected one of %v", c.KeySpec, values),
|
||||
}
|
||||
}
|
||||
|
||||
@ -228,10 +236,10 @@ func (c Config) getKeySpecAsAWSType() *types.DataKeySpec {
|
||||
var spec types.DataKeySpec
|
||||
for _, opt := range spec.Values() {
|
||||
if string(opt) == c.KeySpec {
|
||||
spec = opt
|
||||
return &opt
|
||||
}
|
||||
}
|
||||
return &spec
|
||||
return nil
|
||||
}
|
||||
|
||||
// Mirrored from s3 backend config
|
||||
|
@ -8,6 +8,7 @@ import (
|
||||
|
||||
"github.com/aws/aws-sdk-go-v2/aws"
|
||||
"github.com/aws/aws-sdk-go-v2/feature/ec2/imds"
|
||||
"github.com/aws/aws-sdk-go-v2/service/kms/types"
|
||||
"github.com/davecgh/go-spew/spew"
|
||||
awsbase "github.com/hashicorp/aws-sdk-go-base/v2"
|
||||
"github.com/hashicorp/hcl/v2"
|
||||
@ -243,3 +244,43 @@ func TestValidate(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetKeySpecAsAWSType(t *testing.T) {
|
||||
|
||||
aes256 := types.DataKeySpecAes256
|
||||
aes128 := types.DataKeySpecAes128
|
||||
|
||||
cases := []struct {
|
||||
key string
|
||||
expected *types.DataKeySpec
|
||||
}{
|
||||
{
|
||||
key: "AES_256",
|
||||
expected: &aes256,
|
||||
},
|
||||
{
|
||||
key: "AES_128",
|
||||
expected: &aes128,
|
||||
},
|
||||
{
|
||||
key: "",
|
||||
expected: nil,
|
||||
},
|
||||
{
|
||||
key: "invalidKey",
|
||||
expected: nil,
|
||||
},
|
||||
}
|
||||
|
||||
for _, c := range cases {
|
||||
t.Run(c.key, func(t *testing.T) {
|
||||
config := Config{
|
||||
KeySpec: c.key,
|
||||
}
|
||||
actual := config.getKeySpecAsAWSType()
|
||||
if !reflect.DeepEqual(c.expected, actual) {
|
||||
t.Fatalf("Expected %s, got %s", spew.Sdump(c.expected), spew.Sdump(actual))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
50
internal/encryption/keyprovider/aws_kms/mock_test.go
Normal file
50
internal/encryption/keyprovider/aws_kms/mock_test.go
Normal file
@ -0,0 +1,50 @@
|
||||
package aws_kms
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
|
||||
"github.com/aws/aws-sdk-go-v2/aws"
|
||||
"github.com/aws/aws-sdk-go-v2/service/kms"
|
||||
)
|
||||
|
||||
type mockKMS struct {
|
||||
genkey func(params *kms.GenerateDataKeyInput) (*kms.GenerateDataKeyOutput, error)
|
||||
decrypt func(params *kms.DecryptInput) (*kms.DecryptOutput, error)
|
||||
}
|
||||
|
||||
func (m *mockKMS) GenerateDataKey(ctx context.Context, params *kms.GenerateDataKeyInput, optFns ...func(*kms.Options)) (*kms.GenerateDataKeyOutput, error) {
|
||||
return m.genkey(params)
|
||||
}
|
||||
func (m *mockKMS) Decrypt(ctx context.Context, params *kms.DecryptInput, optFns ...func(*kms.Options)) (*kms.DecryptOutput, error) {
|
||||
return m.decrypt(params)
|
||||
}
|
||||
|
||||
func injectMock(m *mockKMS) {
|
||||
newKMSFromConfig = func(cfg aws.Config) kmsClient {
|
||||
return m
|
||||
}
|
||||
}
|
||||
|
||||
func injectDefaultMock() {
|
||||
injectMock(&mockKMS{
|
||||
genkey: func(params *kms.GenerateDataKeyInput) (*kms.GenerateDataKeyOutput, error) {
|
||||
keyData := make([]byte, 32)
|
||||
_, err := rand.Read(keyData)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
return &kms.GenerateDataKeyOutput{
|
||||
CiphertextBlob: append([]byte(*params.KeyId), keyData...),
|
||||
Plaintext: keyData,
|
||||
}, nil
|
||||
|
||||
},
|
||||
decrypt: func(params *kms.DecryptInput) (*kms.DecryptOutput, error) {
|
||||
return &kms.DecryptOutput{
|
||||
Plaintext: params.CiphertextBlob[:len(*params.KeyId)],
|
||||
}, nil
|
||||
},
|
||||
})
|
||||
}
|
@ -17,17 +17,25 @@ func (m keyMeta) isPresent() bool {
|
||||
return len(m.CiphertextBlob) != 0
|
||||
}
|
||||
|
||||
type kmsClient interface {
|
||||
GenerateDataKey(ctx context.Context, params *kms.GenerateDataKeyInput, optFns ...func(*kms.Options)) (*kms.GenerateDataKeyOutput, error)
|
||||
Decrypt(ctx context.Context, params *kms.DecryptInput, optFns ...func(*kms.Options)) (*kms.DecryptOutput, error)
|
||||
}
|
||||
|
||||
type keyProvider struct {
|
||||
Config
|
||||
svc *kms.Client
|
||||
svc kmsClient
|
||||
ctx context.Context
|
||||
}
|
||||
|
||||
func (p keyProvider) Provide(rawMeta keyprovider.KeyMeta) (keyprovider.Output, keyprovider.KeyMeta, error) {
|
||||
if rawMeta == nil {
|
||||
return keyprovider.Output{}, nil, keyprovider.ErrInvalidMetadata{Message: "bug: no metadata struct provided"}
|
||||
return keyprovider.Output{}, nil, &keyprovider.ErrInvalidMetadata{Message: "bug: no metadata struct provided"}
|
||||
}
|
||||
inMeta, ok := rawMeta.(*keyMeta)
|
||||
if !ok {
|
||||
return keyprovider.Output{}, nil, &keyprovider.ErrInvalidMetadata{Message: "bug: metadata struct is not of the correct type"}
|
||||
}
|
||||
inMeta := rawMeta.(*keyMeta)
|
||||
|
||||
outMeta := &keyMeta{}
|
||||
out := keyprovider.Output{}
|
||||
@ -62,7 +70,7 @@ func (p keyProvider) Provide(rawMeta keyprovider.KeyMeta) (keyprovider.Output, k
|
||||
})
|
||||
|
||||
if decryptErr != nil {
|
||||
return out, outMeta, decryptErr
|
||||
return out, outMeta, &keyprovider.ErrKeyProviderFailure{Cause: decryptErr}
|
||||
}
|
||||
|
||||
// Set decryption key on the output
|
||||
|
@ -1,65 +1,36 @@
|
||||
package aws_kms
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/aws/aws-sdk-go-v2/aws"
|
||||
"github.com/aws/aws-sdk-go-v2/service/kms"
|
||||
"github.com/aws/aws-sdk-go-v2/service/kms/types"
|
||||
awsbase "github.com/hashicorp/aws-sdk-go-base/v2"
|
||||
)
|
||||
|
||||
// skipCheck checks if the test should be skipped or not based on environment variables
|
||||
func skipCheck(t *testing.T) {
|
||||
// check if TF_ACC and TF_KMS_TEST are unset
|
||||
// if so, skip the test
|
||||
func getKey(t *testing.T) string {
|
||||
if os.Getenv("TF_ACC") == "" && os.Getenv("TF_KMS_TEST") == "" {
|
||||
t.Log("Skipping test because TF_ACC or TF_KMS_TEST is not set")
|
||||
t.Skip()
|
||||
return ""
|
||||
}
|
||||
return os.Getenv("TF_AWS_KMS_KEY_ID")
|
||||
}
|
||||
|
||||
const testKeyPrefix = "tf-acc-test-kms-key"
|
||||
const testAliasPrefix = "alias/my-key-alias"
|
||||
|
||||
func TestKMSProvider_Simple(t *testing.T) {
|
||||
skipCheck(t)
|
||||
ctx := context.TODO()
|
||||
testKeyId := getKey(t)
|
||||
if testKeyId == "" {
|
||||
testKeyId = "alias/my-mock-key"
|
||||
injectDefaultMock()
|
||||
|
||||
keyName := fmt.Sprintf("%s-%x", testKeyPrefix, time.Now().Unix())
|
||||
alias := fmt.Sprintf("%s-%x", testAliasPrefix, time.Now().Unix())
|
||||
t.Setenv("AWS_REGION", "us-east-1")
|
||||
t.Setenv("AWS_ACCESS_KEY_ID", "accesskey")
|
||||
t.Setenv("AWS_SECRET_ACCESS_KEY", "secretkey")
|
||||
}
|
||||
|
||||
// Constructs a aws kms key provider config that accepts the alias as the key id
|
||||
// Constructs a aws kms key provider config that accepts the key id
|
||||
providerConfig := Config{
|
||||
KMSKeyID: alias,
|
||||
KMSKeyID: testKeyId,
|
||||
KeySpec: "AES_256",
|
||||
|
||||
SkipCredsValidation: true, // Required for mocking
|
||||
}
|
||||
|
||||
// Mimic the creation of the aws client here via providerConfig.asAWSBase() so that
|
||||
// we create a key in the same way that it will be read
|
||||
awsBaseConfig, err := providerConfig.asAWSBase()
|
||||
if err != nil {
|
||||
t.Fatalf("Error creating AWS config: %s", err)
|
||||
}
|
||||
_, awsConfig, awsDiags := awsbase.GetAwsConfig(ctx, awsBaseConfig)
|
||||
if awsDiags.HasError() {
|
||||
t.Fatalf("Error creating AWS config: %v", awsDiags)
|
||||
}
|
||||
|
||||
kmsClient := kms.NewFromConfig(awsConfig)
|
||||
|
||||
// Create the key
|
||||
keyId := createKMSKey(ctx, t, kmsClient, keyName, awsBaseConfig.Region)
|
||||
defer scheduleKMSKeyDeletion(ctx, t, kms.NewFromConfig(awsConfig), keyId)
|
||||
|
||||
// Create an alias for the key
|
||||
createAlias(ctx, t, kmsClient, keyId, &alias)
|
||||
defer deleteAlias(ctx, t, kms.NewFromConfig(awsConfig), &alias)
|
||||
|
||||
// Now that we have the config, we can build the provider
|
||||
provider, metaIn, err := providerConfig.Build()
|
||||
if err != nil {
|
||||
@ -104,77 +75,3 @@ func TestKMSProvider_Simple(t *testing.T) {
|
||||
t.Fatalf("No ciphertext blob provided")
|
||||
}
|
||||
}
|
||||
|
||||
// createKMSKey creates a KMS key with the given name and region
|
||||
func createKMSKey(ctx context.Context, t *testing.T, kmsClient *kms.Client, keyName string, region string) (keyID string) {
|
||||
createKeyReq := kms.CreateKeyInput{
|
||||
Tags: []types.Tag{
|
||||
{
|
||||
TagKey: aws.String("Name"),
|
||||
TagValue: aws.String(keyName),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
t.Logf("Creating KMS key %s in %s", keyName, region)
|
||||
|
||||
created, err := kmsClient.CreateKey(ctx, &createKeyReq)
|
||||
if err != nil {
|
||||
t.Fatalf("Error creating KMS key: %s", err)
|
||||
}
|
||||
|
||||
return *created.KeyMetadata.KeyId
|
||||
}
|
||||
|
||||
// createAlias creates a KMS alias for the given key
|
||||
func createAlias(ctx context.Context, t *testing.T, kmsClient *kms.Client, keyID string, alias *string) {
|
||||
if alias == nil {
|
||||
return
|
||||
}
|
||||
|
||||
t.Logf("Creating KMS alias %s for key %s", *alias, keyID)
|
||||
|
||||
aliasReq := kms.CreateAliasInput{
|
||||
AliasName: aws.String(*alias),
|
||||
TargetKeyId: aws.String(keyID),
|
||||
}
|
||||
|
||||
_, err := kmsClient.CreateAlias(ctx, &aliasReq)
|
||||
if err != nil {
|
||||
t.Fatalf("Error creating KMS alias: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
// scheduleKMSKeyDeletion schedules the deletion of a KMS key
|
||||
// this attempts to delete it in the fastest possible way (7 days)
|
||||
func scheduleKMSKeyDeletion(ctx context.Context, t *testing.T, kmsClient *kms.Client, keyID string) {
|
||||
deleteKeyReq := kms.ScheduleKeyDeletionInput{
|
||||
KeyId: aws.String(keyID),
|
||||
PendingWindowInDays: aws.Int32(7),
|
||||
}
|
||||
|
||||
t.Logf("Scheduling KMS key %s for deletion", keyID)
|
||||
|
||||
_, err := kmsClient.ScheduleKeyDeletion(ctx, &deleteKeyReq)
|
||||
if err != nil {
|
||||
t.Fatalf("Error deleting KMS key: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
// deleteAlias deletes a KMS alias
|
||||
func deleteAlias(ctx context.Context, t *testing.T, kmsClient *kms.Client, alias *string) {
|
||||
if alias == nil {
|
||||
return
|
||||
}
|
||||
|
||||
t.Logf("Deleting KMS alias %s", *alias)
|
||||
|
||||
deleteAliasReq := kms.DeleteAliasInput{
|
||||
AliasName: aws.String(*alias),
|
||||
}
|
||||
|
||||
_, err := kmsClient.DeleteAlias(ctx, &deleteAliasReq)
|
||||
if err != nil {
|
||||
t.Fatalf("Error deleting KMS alias: %s", err)
|
||||
}
|
||||
}
|
||||
|
@ -131,10 +131,10 @@ func ComplianceTest[TDescriptor keyprovider.Descriptor, TConfig keyprovider.Conf
|
||||
compliancetest.Fail(t, "Please provide at least one metadata test that represents non-present metadata.")
|
||||
}
|
||||
if !hasNotValid {
|
||||
compliancetest.Fail(t, "Please provide at least one metadata test that represents an invalid metadata that is present.")
|
||||
compliancetest.Log(t, "Warning: Please provide at least one metadata test that represents an invalid metadata that is present.")
|
||||
}
|
||||
if !hasValid {
|
||||
compliancetest.Fail(t, "Please provide at least one metadata test that represents a valid metadata.")
|
||||
compliancetest.Log(t, "Warning: Please provide at least one metadata test that represents a valid metadata.")
|
||||
}
|
||||
})
|
||||
})
|
||||
|
Loading…
Reference in New Issue
Block a user