Added aws_kms key provider compliance tests (#1395)

Signed-off-by: James Humphries <james@james-humphries.co.uk>
Signed-off-by: Christian Mesh <christianmesh1@gmail.com>
Co-authored-by: Christian Mesh <christianmesh1@gmail.com>
This commit is contained in:
James Humphries 2024-03-18 18:48:19 +00:00 committed by GitHub
parent 21aa528090
commit 73f5fbf4bc
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 274 additions and 128 deletions

View File

@ -0,0 +1,142 @@
// Copyright (c) The OpenTofu Authors
// SPDX-License-Identifier: MPL-2.0
// Copyright (c) 2023 HashiCorp, Inc.
// SPDX-License-Identifier: MPL-2.0
package aws_kms
import (
"fmt"
"testing"
"github.com/opentofu/opentofu/internal/encryption/keyprovider/compliancetest"
)
func TestKeyProvider(t *testing.T) {
testKeyId := getKey(t)
if testKeyId == "" {
testKeyId = "alias/my-mock-key"
injectDefaultMock()
t.Setenv("AWS_REGION", "us-east-1")
t.Setenv("AWS_ACCESS_KEY_ID", "accesskey")
t.Setenv("AWS_SECRET_ACCESS_KEY", "secretkey")
}
compliancetest.ComplianceTest(
t,
compliancetest.TestConfiguration[*descriptor, *Config, *keyMeta, *keyProvider]{
Descriptor: New().(*descriptor),
HCLParseTestCases: map[string]compliancetest.HCLParseTestCase[*Config, *keyProvider]{
"success": {
HCL: fmt.Sprintf(`key_provider "aws_kms" "foo" {
kms_key_id = "%s"
key_spec = "AES_256"
skip_credentials_validation = true // required for mocking
}`, testKeyId),
ValidHCL: true,
ValidBuild: true,
Validate: func(config *Config, keyProvider *keyProvider) error {
if config.KMSKeyID != testKeyId {
return fmt.Errorf("incorrect key ID returned")
}
return nil
},
},
"empty": {
HCL: `key_provider "aws_kms" "foo" {}`,
ValidHCL: false,
ValidBuild: false,
},
"invalid-key-spec": {
HCL: fmt.Sprintf(`key_provider "aws_kms" "foo" {
kms_key_id = "%s"
key_spec = "BROKEN STUFF"
}`, testKeyId),
ValidHCL: true,
ValidBuild: false,
},
"empty-key-id": {
HCL: `key_provider "aws_kms" "foo" {
kms_key_id = ""
key_spec = "AES_256"
}`,
ValidHCL: true,
ValidBuild: false,
},
"empty-key-spec": {
HCL: `key_provider "aws_kms" "foo" {
kms_key_id = "alias/temp"
key_spec = ""
}`,
ValidHCL: true,
ValidBuild: false,
},
"unknown-property": {
HCL: fmt.Sprintf(`key_provider "aws_kms" "foo" {
kms_key_id = "%s"
key_spec = "AES_256"
unknown_property = "foo"
}`, testKeyId),
ValidHCL: false,
ValidBuild: false,
},
},
ConfigStructTestCases: map[string]compliancetest.ConfigStructTestCase[*Config, *keyProvider]{
"success": {
Config: &Config{
KMSKeyID: testKeyId,
KeySpec: "AES_256",
SkipCredsValidation: true, // Required for mocking
},
ValidBuild: true,
Validate: nil,
},
"empty": {
Config: &Config{
KMSKeyID: "",
KeySpec: "",
},
ValidBuild: false,
Validate: nil,
},
},
MetadataStructTestCases: map[string]compliancetest.MetadataStructTestCase[*Config, *keyMeta]{
"empty": {
ValidConfig: &Config{
KMSKeyID: testKeyId,
KeySpec: "AES_256",
SkipCredsValidation: true, // Required for mocking
},
Meta: &keyMeta{},
IsPresent: false,
IsValid: false,
},
},
ProvideTestCase: compliancetest.ProvideTestCase[*Config, *keyMeta]{
ValidConfig: &Config{
KMSKeyID: testKeyId,
KeySpec: "AES_256",
SkipCredsValidation: true, // Required for mocking
},
ValidateKeys: func(dec []byte, enc []byte) error {
if len(dec) == 0 {
return fmt.Errorf("decryption key is empty")
}
if len(enc) == 0 {
return fmt.Errorf("encryption key is empty")
}
return nil
},
ValidateMetadata: func(meta *keyMeta) error {
if meta.CiphertextBlob == nil || len(meta.CiphertextBlob) == 0 {
return fmt.Errorf("ciphertext blob is nil")
}
return nil
},
},
})
}

View File

@ -17,6 +17,11 @@ import (
"github.com/opentofu/opentofu/version"
)
// Can be overridden for test mocking
var newKMSFromConfig func(aws.Config) kmsClient = func(cfg aws.Config) kmsClient {
return kms.NewFromConfig(cfg)
}
type Config struct {
// KeyProvider Config
KMSKeyID string `hcl:"kms_key_id"`
@ -193,7 +198,7 @@ func (c Config) Build() (keyprovider.KeyProvider, keyprovider.KeyMeta, error) {
return &keyProvider{
Config: c,
svc: kms.NewFromConfig(awsConfig),
svc: newKMSFromConfig(awsConfig),
ctx: ctx,
}, new(keyMeta), nil
}
@ -214,8 +219,11 @@ func (c Config) validate() (err error) {
spec := c.getKeySpecAsAWSType()
if spec == nil {
// This is to fetch a list of the values from the enum, because `spec` here can be nil, so we have to grab
// at least one of the enum possibilities here just to call .Values()
values := types.DataKeySpecAes256.Values()
return &keyprovider.ErrInvalidConfiguration{
Message: fmt.Sprintf("invalid key_spec %s, expected one of %v", c.KeySpec, spec.Values()),
Message: fmt.Sprintf("invalid key_spec %s, expected one of %v", c.KeySpec, values),
}
}
@ -228,10 +236,10 @@ func (c Config) getKeySpecAsAWSType() *types.DataKeySpec {
var spec types.DataKeySpec
for _, opt := range spec.Values() {
if string(opt) == c.KeySpec {
spec = opt
return &opt
}
}
return &spec
return nil
}
// Mirrored from s3 backend config

View File

@ -8,6 +8,7 @@ import (
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/feature/ec2/imds"
"github.com/aws/aws-sdk-go-v2/service/kms/types"
"github.com/davecgh/go-spew/spew"
awsbase "github.com/hashicorp/aws-sdk-go-base/v2"
"github.com/hashicorp/hcl/v2"
@ -243,3 +244,43 @@ func TestValidate(t *testing.T) {
})
}
}
func TestGetKeySpecAsAWSType(t *testing.T) {
aes256 := types.DataKeySpecAes256
aes128 := types.DataKeySpecAes128
cases := []struct {
key string
expected *types.DataKeySpec
}{
{
key: "AES_256",
expected: &aes256,
},
{
key: "AES_128",
expected: &aes128,
},
{
key: "",
expected: nil,
},
{
key: "invalidKey",
expected: nil,
},
}
for _, c := range cases {
t.Run(c.key, func(t *testing.T) {
config := Config{
KeySpec: c.key,
}
actual := config.getKeySpecAsAWSType()
if !reflect.DeepEqual(c.expected, actual) {
t.Fatalf("Expected %s, got %s", spew.Sdump(c.expected), spew.Sdump(actual))
}
})
}
}

View File

@ -0,0 +1,50 @@
package aws_kms
import (
"context"
"crypto/rand"
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/service/kms"
)
type mockKMS struct {
genkey func(params *kms.GenerateDataKeyInput) (*kms.GenerateDataKeyOutput, error)
decrypt func(params *kms.DecryptInput) (*kms.DecryptOutput, error)
}
func (m *mockKMS) GenerateDataKey(ctx context.Context, params *kms.GenerateDataKeyInput, optFns ...func(*kms.Options)) (*kms.GenerateDataKeyOutput, error) {
return m.genkey(params)
}
func (m *mockKMS) Decrypt(ctx context.Context, params *kms.DecryptInput, optFns ...func(*kms.Options)) (*kms.DecryptOutput, error) {
return m.decrypt(params)
}
func injectMock(m *mockKMS) {
newKMSFromConfig = func(cfg aws.Config) kmsClient {
return m
}
}
func injectDefaultMock() {
injectMock(&mockKMS{
genkey: func(params *kms.GenerateDataKeyInput) (*kms.GenerateDataKeyOutput, error) {
keyData := make([]byte, 32)
_, err := rand.Read(keyData)
if err != nil {
panic(err)
}
return &kms.GenerateDataKeyOutput{
CiphertextBlob: append([]byte(*params.KeyId), keyData...),
Plaintext: keyData,
}, nil
},
decrypt: func(params *kms.DecryptInput) (*kms.DecryptOutput, error) {
return &kms.DecryptOutput{
Plaintext: params.CiphertextBlob[:len(*params.KeyId)],
}, nil
},
})
}

View File

@ -17,17 +17,25 @@ func (m keyMeta) isPresent() bool {
return len(m.CiphertextBlob) != 0
}
type kmsClient interface {
GenerateDataKey(ctx context.Context, params *kms.GenerateDataKeyInput, optFns ...func(*kms.Options)) (*kms.GenerateDataKeyOutput, error)
Decrypt(ctx context.Context, params *kms.DecryptInput, optFns ...func(*kms.Options)) (*kms.DecryptOutput, error)
}
type keyProvider struct {
Config
svc *kms.Client
svc kmsClient
ctx context.Context
}
func (p keyProvider) Provide(rawMeta keyprovider.KeyMeta) (keyprovider.Output, keyprovider.KeyMeta, error) {
if rawMeta == nil {
return keyprovider.Output{}, nil, keyprovider.ErrInvalidMetadata{Message: "bug: no metadata struct provided"}
return keyprovider.Output{}, nil, &keyprovider.ErrInvalidMetadata{Message: "bug: no metadata struct provided"}
}
inMeta, ok := rawMeta.(*keyMeta)
if !ok {
return keyprovider.Output{}, nil, &keyprovider.ErrInvalidMetadata{Message: "bug: metadata struct is not of the correct type"}
}
inMeta := rawMeta.(*keyMeta)
outMeta := &keyMeta{}
out := keyprovider.Output{}
@ -62,7 +70,7 @@ func (p keyProvider) Provide(rawMeta keyprovider.KeyMeta) (keyprovider.Output, k
})
if decryptErr != nil {
return out, outMeta, decryptErr
return out, outMeta, &keyprovider.ErrKeyProviderFailure{Cause: decryptErr}
}
// Set decryption key on the output

View File

@ -1,65 +1,36 @@
package aws_kms
import (
"context"
"fmt"
"os"
"testing"
"time"
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/service/kms"
"github.com/aws/aws-sdk-go-v2/service/kms/types"
awsbase "github.com/hashicorp/aws-sdk-go-base/v2"
)
// skipCheck checks if the test should be skipped or not based on environment variables
func skipCheck(t *testing.T) {
// check if TF_ACC and TF_KMS_TEST are unset
// if so, skip the test
func getKey(t *testing.T) string {
if os.Getenv("TF_ACC") == "" && os.Getenv("TF_KMS_TEST") == "" {
t.Log("Skipping test because TF_ACC or TF_KMS_TEST is not set")
t.Skip()
return ""
}
return os.Getenv("TF_AWS_KMS_KEY_ID")
}
const testKeyPrefix = "tf-acc-test-kms-key"
const testAliasPrefix = "alias/my-key-alias"
func TestKMSProvider_Simple(t *testing.T) {
skipCheck(t)
ctx := context.TODO()
testKeyId := getKey(t)
if testKeyId == "" {
testKeyId = "alias/my-mock-key"
injectDefaultMock()
keyName := fmt.Sprintf("%s-%x", testKeyPrefix, time.Now().Unix())
alias := fmt.Sprintf("%s-%x", testAliasPrefix, time.Now().Unix())
t.Setenv("AWS_REGION", "us-east-1")
t.Setenv("AWS_ACCESS_KEY_ID", "accesskey")
t.Setenv("AWS_SECRET_ACCESS_KEY", "secretkey")
}
// Constructs a aws kms key provider config that accepts the alias as the key id
// Constructs a aws kms key provider config that accepts the key id
providerConfig := Config{
KMSKeyID: alias,
KMSKeyID: testKeyId,
KeySpec: "AES_256",
SkipCredsValidation: true, // Required for mocking
}
// Mimic the creation of the aws client here via providerConfig.asAWSBase() so that
// we create a key in the same way that it will be read
awsBaseConfig, err := providerConfig.asAWSBase()
if err != nil {
t.Fatalf("Error creating AWS config: %s", err)
}
_, awsConfig, awsDiags := awsbase.GetAwsConfig(ctx, awsBaseConfig)
if awsDiags.HasError() {
t.Fatalf("Error creating AWS config: %v", awsDiags)
}
kmsClient := kms.NewFromConfig(awsConfig)
// Create the key
keyId := createKMSKey(ctx, t, kmsClient, keyName, awsBaseConfig.Region)
defer scheduleKMSKeyDeletion(ctx, t, kms.NewFromConfig(awsConfig), keyId)
// Create an alias for the key
createAlias(ctx, t, kmsClient, keyId, &alias)
defer deleteAlias(ctx, t, kms.NewFromConfig(awsConfig), &alias)
// Now that we have the config, we can build the provider
provider, metaIn, err := providerConfig.Build()
if err != nil {
@ -104,77 +75,3 @@ func TestKMSProvider_Simple(t *testing.T) {
t.Fatalf("No ciphertext blob provided")
}
}
// createKMSKey creates a KMS key with the given name and region
func createKMSKey(ctx context.Context, t *testing.T, kmsClient *kms.Client, keyName string, region string) (keyID string) {
createKeyReq := kms.CreateKeyInput{
Tags: []types.Tag{
{
TagKey: aws.String("Name"),
TagValue: aws.String(keyName),
},
},
}
t.Logf("Creating KMS key %s in %s", keyName, region)
created, err := kmsClient.CreateKey(ctx, &createKeyReq)
if err != nil {
t.Fatalf("Error creating KMS key: %s", err)
}
return *created.KeyMetadata.KeyId
}
// createAlias creates a KMS alias for the given key
func createAlias(ctx context.Context, t *testing.T, kmsClient *kms.Client, keyID string, alias *string) {
if alias == nil {
return
}
t.Logf("Creating KMS alias %s for key %s", *alias, keyID)
aliasReq := kms.CreateAliasInput{
AliasName: aws.String(*alias),
TargetKeyId: aws.String(keyID),
}
_, err := kmsClient.CreateAlias(ctx, &aliasReq)
if err != nil {
t.Fatalf("Error creating KMS alias: %s", err)
}
}
// scheduleKMSKeyDeletion schedules the deletion of a KMS key
// this attempts to delete it in the fastest possible way (7 days)
func scheduleKMSKeyDeletion(ctx context.Context, t *testing.T, kmsClient *kms.Client, keyID string) {
deleteKeyReq := kms.ScheduleKeyDeletionInput{
KeyId: aws.String(keyID),
PendingWindowInDays: aws.Int32(7),
}
t.Logf("Scheduling KMS key %s for deletion", keyID)
_, err := kmsClient.ScheduleKeyDeletion(ctx, &deleteKeyReq)
if err != nil {
t.Fatalf("Error deleting KMS key: %s", err)
}
}
// deleteAlias deletes a KMS alias
func deleteAlias(ctx context.Context, t *testing.T, kmsClient *kms.Client, alias *string) {
if alias == nil {
return
}
t.Logf("Deleting KMS alias %s", *alias)
deleteAliasReq := kms.DeleteAliasInput{
AliasName: aws.String(*alias),
}
_, err := kmsClient.DeleteAlias(ctx, &deleteAliasReq)
if err != nil {
t.Fatalf("Error deleting KMS alias: %s", err)
}
}

View File

@ -131,10 +131,10 @@ func ComplianceTest[TDescriptor keyprovider.Descriptor, TConfig keyprovider.Conf
compliancetest.Fail(t, "Please provide at least one metadata test that represents non-present metadata.")
}
if !hasNotValid {
compliancetest.Fail(t, "Please provide at least one metadata test that represents an invalid metadata that is present.")
compliancetest.Log(t, "Warning: Please provide at least one metadata test that represents an invalid metadata that is present.")
}
if !hasValid {
compliancetest.Fail(t, "Please provide at least one metadata test that represents a valid metadata.")
compliancetest.Log(t, "Warning: Please provide at least one metadata test that represents a valid metadata.")
}
})
})