From 8c99c7522925b9c9a6cebe8e22138fa39e9706ed Mon Sep 17 00:00:00 2001 From: Janos <86970079+janosdebugs@users.noreply.github.com> Date: Thu, 14 Mar 2024 15:53:40 +0100 Subject: [PATCH] [State Encryption] Compliance tests (#1377) Signed-off-by: Janos <86970079+janosdebugs@users.noreply.github.com> --- .../compliancetest/config_struct.go | 85 +++ internal/encryption/compliancetest/const.go | 12 + internal/encryption/compliancetest/log.go | 29 + internal/encryption/keyprovider/README.md | 4 + internal/encryption/keyprovider/addr.go | 5 - .../keyprovider/compliancetest/compliance.go | 501 ++++++++++++++++++ .../compliancetest/configuration.go | 78 +++ internal/encryption/keyprovider/id.go | 11 +- .../keyprovider/pbkdf2/compliance_test.go | 229 ++++++++ .../encryption/keyprovider/pbkdf2/metadata.go | 6 +- .../encryption/keyprovider/pbkdf2/provider.go | 9 +- .../keyprovider/pbkdf2/provider_test.go | 109 ---- .../encryption/keyprovider/static/config.go | 12 +- .../keyprovider/static/descriptor.go | 7 +- .../encryption/keyprovider/static/provider.go | 21 +- .../keyprovider/static/provider_test.go | 163 +++--- internal/encryption/keyprovider/validation.go | 13 + internal/encryption/method/README.md | 4 + internal/encryption/method/addr.go | 1 + internal/encryption/method/aesgcm/aesgcm.go | 3 + .../method/aesgcm/compliance_test.go | 197 +++++++ internal/encryption/method/aesgcm/config.go | 21 +- .../encryption/method/aesgcm/config_test.go | 13 - .../method/compliancetest/compliance.go | 321 +++++++++++ internal/encryption/method/errors.go | 8 + internal/encryption/method/id.go | 11 +- internal/encryption/methods.go | 2 +- .../registry/compliancetest/compliance.go | 29 + .../compliancetest/compliance_key_provider.go | 137 +++++ .../compliancetest/compliance_method.go | 138 +++++ .../{new.go => registry.go} | 2 +- .../registry_test.go | 17 + internal/encryption/registry/registry.go | 6 +- 33 files changed, 1983 insertions(+), 221 deletions(-) create mode 100644 internal/encryption/compliancetest/config_struct.go create mode 100644 internal/encryption/compliancetest/const.go create mode 100644 internal/encryption/compliancetest/log.go create mode 100644 internal/encryption/keyprovider/compliancetest/compliance.go create mode 100644 internal/encryption/keyprovider/compliancetest/configuration.go create mode 100644 internal/encryption/keyprovider/pbkdf2/compliance_test.go create mode 100644 internal/encryption/keyprovider/validation.go create mode 100644 internal/encryption/method/aesgcm/compliance_test.go create mode 100644 internal/encryption/method/compliancetest/compliance.go create mode 100644 internal/encryption/registry/compliancetest/compliance.go create mode 100644 internal/encryption/registry/compliancetest/compliance_key_provider.go create mode 100644 internal/encryption/registry/compliancetest/compliance_method.go rename internal/encryption/registry/lockingencryptionregistry/{new.go => registry.go} (95%) create mode 100644 internal/encryption/registry/lockingencryptionregistry/registry_test.go diff --git a/internal/encryption/compliancetest/config_struct.go b/internal/encryption/compliancetest/config_struct.go new file mode 100644 index 0000000000..300e9e4c4a --- /dev/null +++ b/internal/encryption/compliancetest/config_struct.go @@ -0,0 +1,85 @@ +// Copyright (c) The OpenTofu Authors +// SPDX-License-Identifier: MPL-2.0 +// Copyright (c) 2023 HashiCorp, Inc. +// SPDX-License-Identifier: MPL-2.0 + +package compliancetest + +import ( + "reflect" + "strings" + "testing" +) + +func ConfigStruct[TConfig any](t *testing.T, configStruct any) { + Log(t, "Testing config struct compliance...") + if configStruct == nil { + Fail(t, "The ConfigStruct() method on the descriptor returns a nil configuration. Please implement this function correctly.") + } else { + Log(t, "The ConfigStruct() method returned a non-nil value.") + } + + configStructPtrType := reflect.TypeOf(configStruct) + if configStructPtrType.Kind() != reflect.Ptr { + Fail(t, "The ConfigStruct() method returns a %T, but it should return a pointer to a struct.", configStruct) + } else { + Log(t, "The ConfigStruct() method returned a pointer.") + } + configStructType := configStructPtrType.Elem() + if configStructType.Kind() != reflect.Struct { + Fail(t, "The ConfigStruct() method returns a pointer to %s, but it should return a pointer to a struct.", configStructType.Elem().Name()) + } else { + Log(t, "The ConfigStruct() method returned a pointer to a struct.") + } + + typedConfigStruct, ok := configStruct.(TConfig) + if !ok { + Fail(t, "The ConfigStruct() method returns a %T instead of a %T", configStruct, typedConfigStruct) + } else { + Log(t, "The ConfigStruct() method correctly returns a %T", typedConfigStruct) + } + + hclTagFound := false + for i := 0; i < configStructType.NumField(); i++ { + field := configStructType.Field(i) + hclTag, ok := field.Tag.Lookup("hcl") + if !ok { + continue + } + hclTagFound = true + if hclTag == "" { + Fail( + t, + "The field '%s' on the config structure %s has an empty HCL tag. Please remove the hcl tag or add a value that matches %s.", + field.Name, + configStructType.Name(), + hclTagRe, + ) + } else { + Log(t, "Found a non-empty hcl tag on field '%s' of %s.", field.Name, configStructType.Name()) + } + hclTagParts := strings.Split(hclTag, ",") + if !hclTagRe.MatchString(hclTagParts[0]) { + Fail( + t, + "The field '%s' on the config structure %s has an invalid hcl tag: %s. Please add a value that matches %s.", + field.Name, + configStructType.Name(), + hclTag, + hclTagRe, + ) + } else { + Log(t, "Found hcl tag on field '%s' of %s matches the name requirements.", field.Name, configStructType.Name()) + } + } + if !hclTagFound { + Fail( + t, + "The configuration struct %s does not contain any fields with hcl tags, which means users will not be able to configure this key provider. Please provide at least one field with an hcl tag.", + configStructType.Name(), + ) + } else { + Log(t, "Found at least one field with a hcl tag.") + } + +} diff --git a/internal/encryption/compliancetest/const.go b/internal/encryption/compliancetest/const.go new file mode 100644 index 0000000000..3248a0116d --- /dev/null +++ b/internal/encryption/compliancetest/const.go @@ -0,0 +1,12 @@ +// Copyright (c) The OpenTofu Authors +// SPDX-License-Identifier: MPL-2.0 +// Copyright (c) 2023 HashiCorp, Inc. +// SPDX-License-Identifier: MPL-2.0 + +package compliancetest + +import ( + "regexp" +) + +var hclTagRe = regexp.MustCompile("^[a-zA-Z0-9_-]+$") diff --git a/internal/encryption/compliancetest/log.go b/internal/encryption/compliancetest/log.go new file mode 100644 index 0000000000..516e3d1a88 --- /dev/null +++ b/internal/encryption/compliancetest/log.go @@ -0,0 +1,29 @@ +// Copyright (c) The OpenTofu Authors +// SPDX-License-Identifier: MPL-2.0 +// Copyright (c) 2023 HashiCorp, Inc. +// SPDX-License-Identifier: MPL-2.0 + +package compliancetest + +import ( + "fmt" + "testing" +) + +// Log writes a log line for a compliance test. +func Log(t *testing.T, msg string, params ...any) { + t.Helper() + t.Logf("\033[32m%s\033[0m", fmt.Sprintf(msg, params...)) +} + +// Fail fails a compliance test. +func Fail(t *testing.T, msg string, params ...any) { + t.Helper() + t.Fatalf("\033[31m%s\033[0m", fmt.Sprintf(msg, params...)) +} + +// Skip skips a compliance test. +func Skip(t *testing.T, msg string, params ...any) { + t.Helper() + t.Skipf("\033[33m%s\033[0m", fmt.Sprintf(msg, params...)) +} diff --git a/internal/encryption/keyprovider/README.md b/internal/encryption/keyprovider/README.md index d81d88bb13..002caa3794 100644 --- a/internal/encryption/keyprovider/README.md +++ b/internal/encryption/keyprovider/README.md @@ -24,6 +24,10 @@ Some key providers need to store data alongside the encrypted data, such as the When you implement a key provider, take a look at the [static](static) key provider as a template. You should never use this provider in production because it exposes users to certain weaknesses in some encryption methods, but it is a simple example for the structure. +### Testing your provider (do this first!) + +Before you even go about writing a key provider, please set up the compliance tests. You can create a single test case that calls `compliancetest.ComplianceTest`. This test suite will run your key provider through all important compliance tests and will make sure that you are not missing anything during the implementation. + ### Implementing the descriptor The descriptor is very simple, you need to implement the [`Descriptor`](descriptor.go) interface in a type. (It does not have to be a struct.) However, make sure that the `ConfigStruct` always returns a struct with `hcl` tags on it. For more information on the `hcl` tags, see the [gohcl documentation](https://godocs.io/github.com/hashicorp/hcl/v2/gohcl). diff --git a/internal/encryption/keyprovider/addr.go b/internal/encryption/keyprovider/addr.go index 422f9b7ab1..4dca01ea41 100644 --- a/internal/encryption/keyprovider/addr.go +++ b/internal/encryption/keyprovider/addr.go @@ -7,15 +7,10 @@ package keyprovider import ( "fmt" - "regexp" "github.com/hashicorp/hcl/v2" ) -// TODO is there a generalized way to regexp-check names? -var addrRe = regexp.MustCompile(`^key_provider\.([a-zA-Z_0-9-]+)\.([a-zA-Z_0-9-]+)$`) -var nameRe = regexp.MustCompile("^([a-zA-Z_0-9-]+)$") - // Addr is a type-alias for key provider address strings that identify a specific key provider configuration. // The Addr is an opaque value. Do not perform string manipulation on it outside the functions supplied by the // keyprovider package. diff --git a/internal/encryption/keyprovider/compliancetest/compliance.go b/internal/encryption/keyprovider/compliancetest/compliance.go new file mode 100644 index 0000000000..af1738bc09 --- /dev/null +++ b/internal/encryption/keyprovider/compliancetest/compliance.go @@ -0,0 +1,501 @@ +// Copyright (c) The OpenTofu Authors +// SPDX-License-Identifier: MPL-2.0 +// Copyright (c) 2023 HashiCorp, Inc. +// SPDX-License-Identifier: MPL-2.0 + +package compliancetest + +import ( + "bytes" + "encoding/json" + "errors" + "reflect" + "testing" + + "github.com/hashicorp/hcl/v2/gohcl" + "github.com/opentofu/opentofu/internal/encryption/compliancetest" + "github.com/opentofu/opentofu/internal/encryption/config" + "github.com/opentofu/opentofu/internal/encryption/keyprovider" +) + +func ComplianceTest[TDescriptor keyprovider.Descriptor, TConfig keyprovider.Config, TMeta keyprovider.KeyMeta, TKeyProvider keyprovider.KeyProvider]( + t *testing.T, + config TestConfiguration[TDescriptor, TConfig, TMeta, TKeyProvider], +) { + var cfg TConfig + cfgType := reflect.TypeOf(cfg) + if cfgType.Kind() != reflect.Ptr || cfgType.Elem().Kind() != reflect.Struct { + compliancetest.Fail(t, "You declared the config type to be %T, but it should be a pointer to a struct. Please fix your call to ComplianceTest().", cfg) + } + + var meta TMeta + metaType := reflect.TypeOf(cfg) + if metaType.Kind() != reflect.Interface { + if metaType.Kind() != reflect.Ptr || metaType.Elem().Kind() != reflect.Struct { + compliancetest.Log(t, "You declared a metadata type as %T, but it should be a pointer to a struct. Please fix your call to ComplianceTest().", meta) + } + } else { + compliancetest.Log(t, "Metadata type declared as interface{}, assuming the key provider does not need metadata. (This will be validated later.)") + } + + t.Run("ID", func(t *testing.T) { + complianceTestID(t, config) + }) + + t.Run("ConfigStruct", func(t *testing.T) { + compliancetest.ConfigStruct[TConfig](t, config.Descriptor.ConfigStruct()) + + t.Run("hcl-parsing", func(t *testing.T) { + if config.HCLParseTestCases == nil { + compliancetest.Fail(t, "Please provide a map in HCLParseTestCases.") + } + for name, tc := range config.HCLParseTestCases { + tc := tc + t.Run(name, func(t *testing.T) { + complianceTestHCLParsingTestCase(t, tc, config) + }) + } + }) + + t.Run("config", func(t *testing.T) { + if config.ConfigStructTestCases == nil { + compliancetest.Fail(t, "Please provide a map in ConfigStructTestCases.") + } + for name, tc := range config.ConfigStructTestCases { + tc := tc + t.Run(name, func(t *testing.T) { + complianceTestConfigCase[TConfig, TKeyProvider, TMeta](t, tc) + }) + } + }) + }) + + t.Run("metadata", func(t *testing.T) { + if config.MetadataStructTestCases == nil { + compliancetest.Fail(t, "Please provide a map in MetadataStructTestCases.") + } + for name, tc := range config.MetadataStructTestCases { + tc := tc + t.Run(name, func(t *testing.T) { + complianceTestMetadataTestCase[TConfig, TKeyProvider, TMeta](t, tc) + }) + } + }) + + t.Run("provide", func(t *testing.T) { + complianceTestProvide[TDescriptor, TConfig, TKeyProvider, TMeta](t, config) + }) + + t.Run("test-completeness", func(t *testing.T) { + t.Run("HCL", func(t *testing.T) { + hasNotValidHCL := false + hasValidHCLNotValidBuild := false + hasValidHCLAndBuild := false + for _, tc := range config.HCLParseTestCases { + if !tc.ValidHCL { + hasNotValidHCL = true + } else { + if tc.ValidBuild { + hasValidHCLAndBuild = true + } else { + hasValidHCLNotValidBuild = true + } + } + } + if !hasNotValidHCL { + compliancetest.Fail(t, "Please define at least one test with an invalid HCL.") + } + if !hasValidHCLNotValidBuild { + compliancetest.Fail(t, "Please define at least one test with a valid HCL that will fail on Build() for validation.") + } + if !hasValidHCLAndBuild { + compliancetest.Fail(t, "Please define at least one test with a valid HCL that will succeed on Build() for validation.") + } + }) + t.Run("metadata", func(t *testing.T) { + hasNotPresent := false + hasNotValid := false + hasValid := false + for _, tc := range config.MetadataStructTestCases { + if !tc.IsPresent { + hasNotPresent = true + } else { + if tc.IsValid { + hasValid = true + } else { + hasNotValid = true + } + } + } + if !hasNotPresent { + compliancetest.Fail(t, "Please provide at least one metadata test that represents non-present metadata.") + } + if !hasNotValid { + compliancetest.Fail(t, "Please provide at least one metadata test that represents an invalid metadata that is present.") + } + if !hasValid { + compliancetest.Fail(t, "Please provide at least one metadata test that represents a valid metadata.") + } + }) + }) +} + +func complianceTestProvide[TDescriptor keyprovider.Descriptor, TConfig keyprovider.Config, TKeyProvider keyprovider.KeyProvider, TMeta keyprovider.KeyMeta]( + t *testing.T, + cfg TestConfiguration[TDescriptor, TConfig, TMeta, TKeyProvider], +) { + if reflect.ValueOf(cfg.ProvideTestCase.ValidConfig).IsNil() { + compliancetest.Fail(t, "Please provide a ValidConfig in ProvideTestCase.") + } + keyProviderConfig := cfg.ProvideTestCase.ValidConfig + t.Run("nil-metadata", func(t *testing.T) { + keyProvider, inMeta := complianceTestBuildConfigAndValidate[TKeyProvider, TMeta](t, keyProviderConfig, true) + + if reflect.ValueOf(inMeta).IsNil() { + compliancetest.Skip(t, "The key provider does not have metadata (no metadata returned from Build()).") + return + } + _, _, err := keyProvider.Provide(nil) + if err == nil { + compliancetest.Fail(t, "Provide() did not return no error when provided with nil metadata.") + } else { + compliancetest.Log(t, "Provide() correctly returned an error when provided with nil metadata (%v).", err) + } + var typedError *keyprovider.ErrInvalidMetadata + if !errors.As(err, &typedError) { + compliancetest.Fail(t, "Provide() returned an error of the type %T instead of %T. Please use the correct typed errors.", err, typedError) + } else { + compliancetest.Log(t, "Provide() correctly returned a %T when provided with nil metadata.", typedError) + } + }) + t.Run("incorrect-metadata-type", func(t *testing.T) { + keyProvider, inMeta := complianceTestBuildConfigAndValidate[TKeyProvider, TMeta](t, keyProviderConfig, true) + if reflect.ValueOf(inMeta).IsNil() { + compliancetest.Skip(t, "The key provider does not have metadata (no metadata returned from Build()).") + return + } + _, _, err := keyProvider.Provide(&struct{}{}) + if err == nil { + compliancetest.Fail(t, "Provide() did not return no error when provided with an incorrect metadata type.") + } else { + compliancetest.Log(t, "Provide() correctly returned an error when provided with an metadata type (%v).", err) + } + var typedError *keyprovider.ErrInvalidMetadata + if !errors.As(err, &typedError) { + compliancetest.Fail(t, "Provide() returned an error of the type %T instead of %T. Please use the correct typed errors.", err, typedError) + } else { + compliancetest.Log(t, "Provide() correctly returned a %T when provided with an incorrect metadata type.", typedError) + } + }) + t.Run("round-trip", func(t *testing.T) { + complianceTestRoundTrip(t, keyProviderConfig, cfg) + }) +} + +func complianceTestRoundTrip[TDescriptor keyprovider.Descriptor, TConfig keyprovider.Config, TKeyProvider keyprovider.KeyProvider, TMeta keyprovider.KeyMeta]( + t *testing.T, + keyProviderConfig TConfig, + cfg TestConfiguration[TDescriptor, TConfig, TMeta, TKeyProvider], +) { + keyProvider, inMeta := complianceTestBuildConfigAndValidate[TKeyProvider, TMeta](t, keyProviderConfig, true) + output, outMeta, err := keyProvider.Provide(inMeta) + if err != nil { + compliancetest.Fail(t, "Provide() failed (%v).", err) + } else { + compliancetest.Log(t, "Provide() succeeded.") + } + if cfg.ProvideTestCase.ValidateMetadata != nil { + if err := cfg.ProvideTestCase.ValidateMetadata(outMeta.(TMeta)); err != nil { + compliancetest.Fail(t, "The metadata after the second Provide() call failed the test (%v).", err) + } + } + + // Create a second key provider to avoid internal state. + keyProvider2, inMeta2 := complianceTestBuildConfigAndValidate[TKeyProvider, TMeta](t, keyProviderConfig, true) + + marshalledMeta, err := json.Marshal(outMeta) + if err != nil { + compliancetest.Fail(t, "JSON-marshalling output meta failed (%v).", err) + } else { + compliancetest.Log(t, "JSON-marshalling output meta succeeded: %s", marshalledMeta) + } + + if err := json.Unmarshal(marshalledMeta, &inMeta2); err != nil { + compliancetest.Fail(t, "JSON-unmarshalling meta failed (%v).", err) + } else { + compliancetest.Log(t, "JSON-unmarshalling meta succeeded.") + } + + output2, outMeta2, err := keyProvider2.Provide(inMeta2) + if err != nil { + compliancetest.Fail(t, "Provide() on the subsequent run failed (%v).", err) + } else { + compliancetest.Log(t, "Provide() on the subsequent run succeeded.") + } + + if cfg.ProvideTestCase.ExpectedOutput != nil { + if !bytes.Equal(cfg.ProvideTestCase.ExpectedOutput.EncryptionKey, output.EncryptionKey) { + compliancetest.Fail(t, "Incorrect encryption key received after the first Provide() call. Please set a break point to the line of this error message to debug this error.") + } + if !bytes.Equal(cfg.ProvideTestCase.ExpectedOutput.DecryptionKey, output2.DecryptionKey) { + compliancetest.Fail(t, "Incorrect decryption key received after the second Provide() call. Please set a break point to the line of this error message to debug this error.") + } + if !bytes.Equal(cfg.ProvideTestCase.ExpectedOutput.EncryptionKey, output2.EncryptionKey) { + compliancetest.Fail(t, "Incorrect encryption key received after the second Provide() call. Please set a break point to the line of this error message to debug this error.") + } + } + if cfg.ProvideTestCase.ValidateMetadata != nil { + if err := cfg.ProvideTestCase.ValidateMetadata(outMeta2.(TMeta)); err != nil { + compliancetest.Fail(t, "The metadata after the second Provide() call failed the test (%v).", err) + } + } + if cfg.ProvideTestCase.ValidateKeys == nil { + if !bytes.Equal(output2.DecryptionKey, output.EncryptionKey) { + compliancetest.Fail( + t, + "The encryption key from the first call to Provide() does not match the decryption key provided by the second Provide() call. If you intend the two keys to be different, please provide an ProvideTestCase.ValidateKeys function. If this is not intended, please set a break point to the line of this error message.", + ) + } else { + compliancetest.Log( + t, + "The encryption and decryption keys match.", + ) + } + } else { + if err := cfg.ProvideTestCase.ValidateKeys(output2.DecryptionKey, output.EncryptionKey); err != nil { + compliancetest.Fail( + t, + "The encryption key from the first call to Provide() does not match the decryption key provided by the second Provide() call (%v),", + err, + ) + } else { + compliancetest.Log( + t, + "The encryption and decryption keys match.", + ) + } + } +} + +func complianceTestID[TDescriptor keyprovider.Descriptor, TConfig keyprovider.Config, TMeta keyprovider.KeyMeta, TKeyProvider keyprovider.KeyProvider]( + t *testing.T, + config TestConfiguration[TDescriptor, TConfig, TMeta, TKeyProvider], +) { + id := config.Descriptor.ID() + if id == "" { + compliancetest.Fail(t, "ID is empty.") + } else { + compliancetest.Log(t, "ID is not empty.") + } + if err := id.Validate(); err != nil { + compliancetest.Fail(t, "ID failed validation: %s", id) + } else { + compliancetest.Log(t, "ID passed validation.") + } +} + +func complianceTestHCLParsingTestCase[TDescriptor keyprovider.Descriptor, TConfig keyprovider.Config, TMeta keyprovider.KeyMeta, TKeyProvider keyprovider.KeyProvider]( + t *testing.T, + tc HCLParseTestCase[TConfig, TKeyProvider], + cfg TestConfiguration[TDescriptor, TConfig, TMeta, TKeyProvider], +) { + parseError := false + parsedConfig, diags := config.LoadConfigFromString("config.hcl", tc.HCL) + if tc.ValidHCL { + if diags.HasErrors() { + compliancetest.Fail(t, "Unexpected HCL error (%v).", diags) + } else { + compliancetest.Log(t, "HCL successfully parsed.") + } + } else { + if diags.HasErrors() { + parseError = true + } + } + + configStruct := cfg.Descriptor.ConfigStruct() + diags = gohcl.DecodeBody( + parsedConfig.KeyProviderConfigs[0].Body, + nil, + configStruct, + ) + var keyProvider TKeyProvider + if tc.ValidHCL { + if diags.HasErrors() { + compliancetest.Fail(t, "Failed to parse empty HCL block into config struct (%v).", diags) + } else { + compliancetest.Log(t, "HCL successfully loaded into config struct.") + } + + keyProvider, _ = complianceTestBuildConfigAndValidate[TKeyProvider, TMeta](t, configStruct, tc.ValidBuild) + } else { + if !parseError && !diags.HasErrors() { + compliancetest.Fail(t, "Expected error during HCL parsing, but no error was returned.") + } else { + compliancetest.Log(t, "HCL loading errored correctly (%v).", diags) + } + } + + if tc.Validate != nil { + if err := tc.Validate(configStruct.(TConfig), keyProvider); err != nil { + compliancetest.Fail(t, "Error during validation and configuration (%v).", err) + } else { + compliancetest.Log(t, "Successfully validated parsed HCL config and applied modifications.") + } + } else { + compliancetest.Log(t, "No ValidateAndConfigure provided, skipping HCL parse validation.") + } +} + +func complianceTestConfigCase[TConfig keyprovider.Config, TKeyProvider keyprovider.KeyProvider, TMeta keyprovider.KeyMeta]( + t *testing.T, + tc ConfigStructTestCase[TConfig, TKeyProvider], +) { + keyProvider, _ := complianceTestBuildConfigAndValidate[TKeyProvider, TMeta](t, tc.Config, tc.ValidBuild) + if tc.Validate != nil { + if err := tc.Validate(keyProvider); err != nil { + compliancetest.Fail(t, "Error during validation and configuration (%v).", err) + } else { + compliancetest.Log(t, "Successfully validated parsed HCL config and applied modifications.") + } + } else { + compliancetest.Log(t, "No ValidateAndConfigure provided, skipping HCL parse validation.") + } +} + +func complianceTestBuildConfigAndValidate[TKeyProvider keyprovider.KeyProvider, TMeta keyprovider.KeyMeta]( + t *testing.T, + configStruct keyprovider.Config, + validBuild bool, +) (TKeyProvider, TMeta) { + if configStruct == nil { + compliancetest.Fail(t, "Nil struct passed!") + } + + var typedKeyProvider TKeyProvider + var typedMeta TMeta + var ok bool + kp, meta, err := configStruct.Build() + if validBuild { + if err != nil { + compliancetest.Fail(t, "Build() returned an unexpected error: %v.", err) + } else { + compliancetest.Log(t, "Build() did not return an error.") + } + typedKeyProvider, ok = kp.(TKeyProvider) + if !ok { + compliancetest.Fail(t, "Build() returned an invalid key provider type of %T, expected %T", kp, typedKeyProvider) + } else { + compliancetest.Log(t, "Build() returned the correct key provider type of %T.", typedKeyProvider) + } + + metaType := reflect.TypeOf(typedMeta) + if meta == nil { + if metaType.Kind() != reflect.Interface { + compliancetest.Fail(t, "Build() did not return a metadata, but you declared a metadata type. Please make sure that you always return the same metadata type.") + } else { + compliancetest.Log(t, "Build() did not return a metadata and the declared metadata type is interface{}.") + } + } else { + if metaType.Kind() == reflect.Interface { + compliancetest.Fail(t, "Build() returned metadata, but you declared an interface type as the metadata type. Please always declare a pointer to a struct as a metadata type.") + } else { + compliancetest.Log(t, "Build() returned metadata and the declared metadata type is not an interface.") + } + typedMeta, ok = meta.(TMeta) + if !ok { + compliancetest.Fail(t, "Build() returned an invalid metadata type of %T, expected %T", meta, typedMeta) + } else { + compliancetest.Log(t, "Build() returned the correct metadata type of %T.", meta) + } + } + } else { + if err == nil { + compliancetest.Fail(t, "Build() did not return an error.") + } else { + compliancetest.Log(t, "Build() correctly returned an error: %v", err) + } + + var typedError *keyprovider.ErrInvalidConfiguration + if !errors.As(err, &typedError) { + compliancetest.Fail( + t, + "Build() did not return the correct error type, got %T but expected %T", + err, + typedError, + ) + } else { + compliancetest.Log(t, "Build() returned the correct error type of %T", typedError) + } + } + return typedKeyProvider, typedMeta +} + +func complianceTestMetadataTestCase[TConfig keyprovider.Config, TKeyProvider keyprovider.KeyProvider, TMeta keyprovider.KeyMeta]( + t *testing.T, + tc MetadataStructTestCase[TConfig, TMeta], +) { + keyProvider, _ := complianceTestBuildConfigAndValidate[TKeyProvider, TMeta](t, tc.ValidConfig, true) + + output, _, err := keyProvider.Provide(tc.Meta) + if tc.IsPresent { + // This test case means that the input metadata should be considered present, so it's either an error or a + // decryption key. + if tc.IsValid { + if err != nil { + var typedError *keyprovider.ErrKeyProviderFailure + if !errors.As(err, &typedError) { + compliancetest.Fail( + t, + "The Provide() function returned an unexpected error, which was also of the incorrect type of %T instead of %T: %v", + err, + typedError, + err, + ) + } + compliancetest.Fail(t, "The Provide() function returned an unexpected error: %v", err) + } + } else { + if err == nil { + compliancetest.Fail(t, "The Provide() function did not return an error as expected.") + } else { + compliancetest.Log(t, "The Provide() function returned an expected error: %v", err) + } + + var typedError *keyprovider.ErrInvalidMetadata + if !errors.As(err, &typedError) { + compliancetest.Fail( + t, + "The Provide() function returned the error type of %T instead of %T. Please use the correct typed errors.", + err, + typedError, + ) + } + } + } else { + if err != nil { + var typedError *keyprovider.ErrKeyProviderFailure + if !errors.As(err, &typedError) { + compliancetest.Fail( + t, + "The Provide() function returned an unexpected error, which was also of the incorrect type of %T instead of %T: %v", + err, + typedError, + err, + ) + } + compliancetest.Fail(t, "The Provide() function returned an unexpected error: %v", err) + } + if len(output.DecryptionKey) != 0 { + compliancetest.Fail( + t, + "The Provide() function a decryption key despite not receiving input meta. This is incorrect, please don't return a decryption key unless you receive the input metadata.", + ) + } else { + compliancetest.Log( + t, + "The Provide() function correctly did not return a decryption key without input metadata.", + ) + } + } +} diff --git a/internal/encryption/keyprovider/compliancetest/configuration.go b/internal/encryption/keyprovider/compliancetest/configuration.go new file mode 100644 index 0000000000..e7839122a8 --- /dev/null +++ b/internal/encryption/keyprovider/compliancetest/configuration.go @@ -0,0 +1,78 @@ +// Copyright (c) The OpenTofu Authors +// SPDX-License-Identifier: MPL-2.0 +// Copyright (c) 2023 HashiCorp, Inc. +// SPDX-License-Identifier: MPL-2.0 + +package compliancetest + +import ( + "github.com/opentofu/opentofu/internal/encryption/keyprovider" +) + +type TestConfiguration[TDescriptor keyprovider.Descriptor, TConfig keyprovider.Config, TMeta any, TKeyProvider keyprovider.KeyProvider] struct { + // Descriptor is the descriptor for the key provider. + Descriptor TDescriptor + + // HCLParseTestCases contains the test cases of parsing HCL configuration and then validating it using the Build() + // function. + HCLParseTestCases map[string]HCLParseTestCase[TConfig, TKeyProvider] + + // ConfigStructT validates that a certain config results or does not result in a valid Build() call. + ConfigStructTestCases map[string]ConfigStructTestCase[TConfig, TKeyProvider] + + // MetadataStructTestCases test various metadata values for correct handling. + MetadataStructTestCases map[string]MetadataStructTestCase[TConfig, TMeta] + + // ProvideTestCase exercises the entire chain and generates two keys. + ProvideTestCase ProvideTestCase[TConfig, TMeta] +} + +// HCLParseTestCase contains a test case that parses HCL into a configuration. +type HCLParseTestCase[TConfig keyprovider.Config, TKeyProvider keyprovider.KeyProvider] struct { + // HCL contains the code that should be parsed into the configuration structure. + HCL string + // ValidHCL indicates that the HCL block should be parsable into the configuration structure, but not necessarily + // result in a valid Build() call. + ValidHCL bool + // ValidBuild indicates that calling the Build() function should not result in an error. + ValidBuild bool + // Validate is an extra optional validation function that can check if the configuration contains the correct + // values parsed from HCL. If ValidBuild is true, the key provider will be passed as well. + Validate func(config TConfig, keyProvider TKeyProvider) error +} + +// ConfigStructTestCase validates that the config struct is behaving correctly when Build() is called. +type ConfigStructTestCase[TConfig keyprovider.Config, TKeyProvider keyprovider.KeyProvider] struct { + Config TConfig + ValidBuild bool + Validate func(keyProvider TKeyProvider) error +} + +// MetadataStructTestCase is a test case for metadata. +type MetadataStructTestCase[TConfig keyprovider.Config, TMeta any] struct { + // Config contains a valid configuration that should be used to construct the key provider. + ValidConfig TConfig + // Meta contains the metadata for this test case. + Meta TMeta + // IsPresent indicates that the supplied metadata in Meta should be treated as present and the Provide() function + // should either return an error or a decryption key. If IsPresent is false, the Provide() function must not + // return an error or a decryption key. + IsPresent bool + // IsValid indicates that, if IsPresent is true, the metadata should be valid and the Provide() function should not + // exit with a *keyprovider.ErrInvalidMetadata error. + IsValid bool +} + +// ProvideTestCase provides a test configuration Provide() test where a key is requested and then +// subsequently compared. +type ProvideTestCase[TConfig keyprovider.Config, TMeta any] struct { + // ValidConfig is a valid configuration that the integration test can use to generate keys. + ValidConfig TConfig + // ExpectedOutput indicates what keys are expected as an output when the integration test is ran with full metadata. + ExpectedOutput *keyprovider.Output + // ValidateKeys is a function that compares an encryption and a decryption key. The function should return an error + // if the two keys don't belong together. If you do not provide this function, bytes.Equal will be used. + ValidateKeys func(decryptionKey []byte, encryptionKey []byte) error + // ValidateMetadata is a function that validates that the resulting metadata is correct. + ValidateMetadata func(meta TMeta) error +} diff --git a/internal/encryption/keyprovider/id.go b/internal/encryption/keyprovider/id.go index 7928a395ee..1ce8e25bbe 100644 --- a/internal/encryption/keyprovider/id.go +++ b/internal/encryption/keyprovider/id.go @@ -5,11 +5,18 @@ package keyprovider +import "fmt" + // ID is a type alias to make passing the wrong ID into a key provider harder. type ID string // Validate validates the key provider ID for correctness. -func (i ID) Validate() error { - // TODO implement format checking +func (id ID) Validate() error { + if id == "" { + return fmt.Errorf("empty key provider ID (key provider IDs must match %s)", idRe.String()) + } + if !idRe.MatchString(string(id)) { + return fmt.Errorf("invalid key provider ID: %s (must match %s)", id, idRe.String()) + } return nil } diff --git a/internal/encryption/keyprovider/pbkdf2/compliance_test.go b/internal/encryption/keyprovider/pbkdf2/compliance_test.go new file mode 100644 index 0000000000..05ca3c8fb3 --- /dev/null +++ b/internal/encryption/keyprovider/pbkdf2/compliance_test.go @@ -0,0 +1,229 @@ +// Copyright (c) The OpenTofu Authors +// SPDX-License-Identifier: MPL-2.0 +// Copyright (c) 2023 HashiCorp, Inc. +// SPDX-License-Identifier: MPL-2.0 + +package pbkdf2 + +import ( + "crypto/rand" + "fmt" + "testing" + + "github.com/opentofu/opentofu/internal/encryption/keyprovider" + "github.com/opentofu/opentofu/internal/encryption/keyprovider/compliancetest" +) + +func TestCompliance(t *testing.T) { + validConfig := &Config{ + randomSource: rand.Reader, + Passphrase: "Hello world! 123", + KeyLength: DefaultKeyLength, + Iterations: DefaultIterations, + HashFunction: SHA256HashFunctionName, + SaltLength: DefaultSaltLength, + } + compliancetest.ComplianceTest( + t, + compliancetest.TestConfiguration[*descriptor, *Config, *Metadata, *pbkdf2KeyProvider]{ + Descriptor: New().(*descriptor), + HCLParseTestCases: map[string]compliancetest.HCLParseTestCase[*Config, *pbkdf2KeyProvider]{ + "empty": { + HCL: `key_provider "pbkdf2" "foo" { +}`, + ValidHCL: false, + ValidBuild: false, + Validate: nil, + }, + "basic": { + HCL: `key_provider "pbkdf2" "foo" { + passphrase = "Hello world! 123" +}`, + ValidHCL: true, + ValidBuild: true, + Validate: func(config *Config, keyProvider *pbkdf2KeyProvider) error { + if config.Passphrase != "Hello world! 123" { + return fmt.Errorf("invalid passphrase after HCL parsing") + } + if keyProvider.Passphrase != "Hello world! 123" { + return fmt.Errorf("invalid passphrase in key provideer") + } + return nil + }, + }, + "extended": { + HCL: fmt.Sprintf(`key_provider "pbkdf2" "foo" { + passphrase = "Hello world! 123" + key_length = %d + iterations = %d + salt_length = %d + hash_function = "%s" +}`, DefaultKeyLength+1, DefaultIterations+1, DefaultSaltLength+1, SHA256HashFunctionName), + ValidHCL: true, + ValidBuild: true, + Validate: func(config *Config, keyProvider *pbkdf2KeyProvider) error { + if config.KeyLength != DefaultKeyLength+1 { + return fmt.Errorf("incorrect key length after HCL parsing: %d", config.KeyLength) + } + if config.Iterations != DefaultIterations+1 { + return fmt.Errorf("incorrect iterations after HCL parsing: %d", config.Iterations) + } + if config.SaltLength != DefaultSaltLength+1 { + return fmt.Errorf("incorrect salt length after HCL parsing: %d", config.SaltLength) + } + if config.HashFunction != SHA256HashFunctionName { + return fmt.Errorf("incorrect hash function after HCL parsing: %s", config.HashFunction) + } + return nil + }, + }, + "short-passphrase": { + HCL: `key_provider "pbkdf2" "foo" { + passphrase = "Hello world! 12" +}`, + ValidHCL: true, + ValidBuild: false, + }, + "too-small-iterations": { + HCL: fmt.Sprintf(`key_provider "pbkdf2" "foo" { + passphrase = "Hello world! 123" + iterations = %d +}`, MinimumIterations-1), + ValidHCL: true, + ValidBuild: false, + }, + "invalid-hash-function": { + HCL: `key_provider "pbkdf2" "foo" { + passphrase = "Hello world! 123" + hash_function = "non_existent" +}`, + ValidHCL: true, + ValidBuild: false, + }, + }, + ConfigStructTestCases: map[string]compliancetest.ConfigStructTestCase[*Config, *pbkdf2KeyProvider]{}, + MetadataStructTestCases: map[string]compliancetest.MetadataStructTestCase[*Config, *Metadata]{ + "not-present-salt": { + ValidConfig: validConfig, + Meta: &Metadata{ + Salt: nil, + Iterations: DefaultIterations, + HashFunction: SHA256HashFunctionName, + KeyLength: 32, + }, + IsPresent: false, + }, + "not-present-iterations": { + ValidConfig: validConfig, + Meta: &Metadata{ + Salt: []byte("Hello world!"), + Iterations: 0, + HashFunction: SHA256HashFunctionName, + KeyLength: 32, + }, + IsPresent: false, + }, + "not-present-hash-func": { + ValidConfig: validConfig, + Meta: &Metadata{ + Salt: []byte("Hello world!"), + Iterations: DefaultIterations, + HashFunction: "", + KeyLength: 32, + }, + IsPresent: false, + }, + "not-present-key-length": { + ValidConfig: validConfig, + Meta: &Metadata{ + Salt: []byte("Hello world!"), + Iterations: DefaultIterations, + HashFunction: SHA256HashFunctionName, + KeyLength: 0, + }, + IsPresent: false, + }, + "present-valid": { + ValidConfig: validConfig, + Meta: &Metadata{ + Salt: []byte("Hello world!"), + Iterations: DefaultIterations, + HashFunction: SHA256HashFunctionName, + KeyLength: 32, + }, + IsPresent: true, + IsValid: true, + }, + "present-valid-too-few-iterations": { + ValidConfig: validConfig, + Meta: &Metadata{ + Salt: []byte("Hello world!"), + Iterations: MinimumIterations - 1, + HashFunction: SHA256HashFunctionName, + KeyLength: 32, + }, + IsPresent: true, + IsValid: true, + }, + "invalid-iterations": { + ValidConfig: validConfig, + Meta: &Metadata{ + Salt: []byte("Hello world!"), + Iterations: -1, + HashFunction: SHA256HashFunctionName, + KeyLength: 32, + }, + IsPresent: true, + IsValid: false, + }, + "invalid-salt-length": { + ValidConfig: validConfig, + Meta: &Metadata{ + Salt: []byte("Hello world!"), + Iterations: DefaultIterations, + HashFunction: SHA256HashFunctionName, + KeyLength: -1, + }, + IsPresent: true, + IsValid: false, + }, + }, + ProvideTestCase: compliancetest.ProvideTestCase[*Config, *Metadata]{ + ValidConfig: &Config{ + randomSource: &testRandomSource{t: t}, + Passphrase: "Hello world! 123", + KeyLength: DefaultKeyLength, + Iterations: DefaultIterations, + HashFunction: DefaultHashFunctionName, + SaltLength: DefaultSaltLength, + }, + ExpectedOutput: &keyprovider.Output{ + EncryptionKey: []byte{87, 192, 98, 53, 186, 42, 63, 139, 58, 118, 223, 169, 46, 84, 139, 29, 130, 59, 247, 106, 82, 61, 235, 144, 97, 131, 60, 229, 195, 109, 81, 111}, + DecryptionKey: []byte{87, 192, 98, 53, 186, 42, 63, 139, 58, 118, 223, 169, 46, 84, 139, 29, 130, 59, 247, 106, 82, 61, 235, 144, 97, 131, 60, 229, 195, 109, 81, 111}, + }, + ValidateKeys: nil, + ValidateMetadata: func(meta *Metadata) error { + if !meta.isPresent() { + return fmt.Errorf("output metadata is not present") + } + if err := meta.validate(); err != nil { + return err + } + if meta.KeyLength != DefaultKeyLength { + return fmt.Errorf("incorrect output metadata key length: %d", meta.KeyLength) + } + if meta.Iterations != DefaultIterations { + return fmt.Errorf("incorrect output metadata iterations: %d", meta.Iterations) + } + if len(meta.Salt) != DefaultSaltLength { + return fmt.Errorf("incorrect output salt length: %d", len(meta.Salt)) + } + if meta.HashFunction != DefaultHashFunctionName { + return fmt.Errorf("incorrect output hash function name: %s", meta.HashFunction) + } + return nil + }, + }, + }, + ) +} diff --git a/internal/encryption/keyprovider/pbkdf2/metadata.go b/internal/encryption/keyprovider/pbkdf2/metadata.go index eff51903f6..747aa2cfa3 100644 --- a/internal/encryption/keyprovider/pbkdf2/metadata.go +++ b/internal/encryption/keyprovider/pbkdf2/metadata.go @@ -25,18 +25,18 @@ func (m Metadata) isPresent() bool { func (m Metadata) validate() error { if m.Iterations < 0 { - return keyprovider.ErrInvalidMetadata{ + return &keyprovider.ErrInvalidMetadata{ Message: fmt.Sprintf("invalid number of iterations (%d)", m.Iterations), } } if m.KeyLength < 0 { - return keyprovider.ErrInvalidMetadata{ + return &keyprovider.ErrInvalidMetadata{ Message: fmt.Sprintf("invalid key length (%d)", m.KeyLength), } } if m.HashFunction != "" { if err := m.HashFunction.Validate(); err != nil { - return keyprovider.ErrInvalidMetadata{ + return &keyprovider.ErrInvalidMetadata{ Message: "invalid hash function name", Cause: err, } diff --git a/internal/encryption/keyprovider/pbkdf2/provider.go b/internal/encryption/keyprovider/pbkdf2/provider.go index 5d811a9827..856e16b952 100644 --- a/internal/encryption/keyprovider/pbkdf2/provider.go +++ b/internal/encryption/keyprovider/pbkdf2/provider.go @@ -39,9 +39,14 @@ func (p pbkdf2KeyProvider) generateMetadata() (*Metadata, error) { func (p pbkdf2KeyProvider) Provide(rawMeta keyprovider.KeyMeta) (keyprovider.Output, keyprovider.KeyMeta, error) { if rawMeta == nil { - return keyprovider.Output{}, nil, keyprovider.ErrInvalidMetadata{Message: "bug: no metadata struct provided"} + return keyprovider.Output{}, nil, &keyprovider.ErrInvalidMetadata{Message: "bug: no metadata struct provided"} + } + inMeta, ok := rawMeta.(*Metadata) + if !ok { + return keyprovider.Output{}, nil, &keyprovider.ErrInvalidMetadata{ + Message: fmt.Sprintf("bug: incorrect metadata type of %T provided", rawMeta), + } } - inMeta := rawMeta.(*Metadata) outMeta, err := p.generateMetadata() if err != nil { diff --git a/internal/encryption/keyprovider/pbkdf2/provider_test.go b/internal/encryption/keyprovider/pbkdf2/provider_test.go index 26ef30c562..44ce263a61 100644 --- a/internal/encryption/keyprovider/pbkdf2/provider_test.go +++ b/internal/encryption/keyprovider/pbkdf2/provider_test.go @@ -76,95 +76,6 @@ func TestBadReader(t *testing.T) { } } -func TestInvalidMetadata(t *testing.T) { - provider := pbkdf2KeyProvider{ - Config{ - randomSource: testRandomSource{t}, - Passphrase: "Hello world!", - KeyLength: 32, - Iterations: MinimumIterations, - HashFunction: SHA256HashFunctionName, - SaltLength: 12, - }, - } - - if _, _, err := provider.Provide(&Metadata{ - Iterations: 1, - Salt: []byte("Hello world!"), - HashFunction: SHA256HashFunctionName, - KeyLength: -1, - }); err == nil { - t.Fatalf("expected error") - } -} - -func TestNilMetadata(t *testing.T) { - provider := pbkdf2KeyProvider{ - Config{ - randomSource: testRandomSource{t}, - Passphrase: "Hello world!", - KeyLength: 32, - Iterations: MinimumIterations, - HashFunction: SHA256HashFunctionName, - SaltLength: 12, - }, - } - - if _, _, err := provider.Provide(nil); err == nil { - t.Fatalf("expected error") - } -} - -func TestOutputMetadata(t *testing.T) { - provider := pbkdf2KeyProvider{ - Config{ - randomSource: testRandomSource{t}, - Passphrase: "Hello world!", - KeyLength: 32, - Iterations: MinimumIterations, - HashFunction: SHA256HashFunctionName, - SaltLength: 12, - }, - } - - _, meta, err := provider.Provide(&Metadata{}) - if err != nil { - t.Fatalf("%v", err) - } - if !meta.(*Metadata).isPresent() { - t.Fatalf("result metadata is not present") - } - if err := meta.(*Metadata).validate(); err != nil { - t.Fatalf("result metadata is not valid (%v)", err) - } -} - -func TestFullCircle(t *testing.T) { - provider := pbkdf2KeyProvider{ - Config{ - randomSource: rand.Reader, - Passphrase: "Hello world!", - KeyLength: 32, - Iterations: MinimumIterations, - HashFunction: SHA256HashFunctionName, - SaltLength: 12, - }, - } - - encryptionKeys, meta, err := provider.Provide(&Metadata{}) - if err != nil { - t.Fatalf("%v", err) - } - - decryptionKeys, _, err := provider.Provide(meta) - if err != nil { - t.Fatalf("%v", err) - } - if !bytes.Equal(encryptionKeys.EncryptionKey, decryptionKeys.DecryptionKey) { - t.Fatalf("The two keys don't match: %x / %x", encryptionKeys.EncryptionKey, decryptionKeys.DecryptionKey) - } -} - func TestKeyLength(t *testing.T) { provider := pbkdf2KeyProvider{ Config{ @@ -184,23 +95,3 @@ func TestKeyLength(t *testing.T) { t.Fatalf("incorrect key length: %d", length) } } - -func TestNoDecryptionKeyOnEmptyInputMeta(t *testing.T) { - provider := pbkdf2KeyProvider{ - Config{ - randomSource: rand.Reader, - Passphrase: "Hello world!", - KeyLength: 128, - Iterations: MinimumIterations, - HashFunction: SHA256HashFunctionName, - SaltLength: 12, - }, - } - keys, _, err := provider.Provide(&Metadata{}) - if err != nil { - t.Fatalf("%v", err) - } - if len(keys.DecryptionKey) != 0 { - t.Fatalf("decryption key generated despite no input metadata") - } -} diff --git a/internal/encryption/keyprovider/static/config.go b/internal/encryption/keyprovider/static/config.go index 20ab74a624..ec5302cf99 100644 --- a/internal/encryption/keyprovider/static/config.go +++ b/internal/encryption/keyprovider/static/config.go @@ -7,7 +7,6 @@ package static import ( "encoding/hex" - "fmt" "github.com/opentofu/opentofu/internal/encryption/keyprovider" ) @@ -20,9 +19,18 @@ type Config struct { // Build will create the usable key provider. func (c Config) Build() (keyprovider.KeyProvider, keyprovider.KeyMeta, error) { + if c.Key == "" { + return nil, nil, &keyprovider.ErrInvalidConfiguration{ + Message: "Missing key", + } + } + decodedData, err := hex.DecodeString(c.Key) if err != nil { - return nil, nil, fmt.Errorf("failed to hex-decode the provided key (%w)", err) + return nil, nil, &keyprovider.ErrInvalidConfiguration{ + Message: "failed to hex-decode the provided key", + Cause: err, + } } return &staticKeyProvider{decodedData}, new(Metadata), nil diff --git a/internal/encryption/keyprovider/static/descriptor.go b/internal/encryption/keyprovider/static/descriptor.go index e4077e2ff1..b57fbf3055 100644 --- a/internal/encryption/keyprovider/static/descriptor.go +++ b/internal/encryption/keyprovider/static/descriptor.go @@ -9,10 +9,15 @@ import ( "github.com/opentofu/opentofu/internal/encryption/keyprovider" ) -func New() keyprovider.Descriptor { +func New() Descriptor { return &descriptor{} } +// Descriptor is an additional interface to allow for providing custom methods. +type Descriptor interface { + keyprovider.Descriptor +} + type descriptor struct { } diff --git a/internal/encryption/keyprovider/static/provider.go b/internal/encryption/keyprovider/static/provider.go index 46e12f804b..368b41e777 100644 --- a/internal/encryption/keyprovider/static/provider.go +++ b/internal/encryption/keyprovider/static/provider.go @@ -23,17 +23,30 @@ func (p staticKeyProvider) Provide(meta keyprovider.KeyMeta) (keyprovider.Output // but it illustrates well how you can store and retrieve metadata. We wish we could use generics to // save you the trouble of doing a type assertion, but Go does not have sufficiently advanced enough generics // to do that. + if meta == nil { + return keyprovider.Output{}, nil, &keyprovider.ErrInvalidMetadata{ + Message: "bug: nil provided as metadata", + } + } typedMeta, ok := meta.(*Metadata) if !ok { - return keyprovider.Output{}, nil, fmt.Errorf("bug: invalid metadata type received: %T", meta) + return keyprovider.Output{}, nil, &keyprovider.ErrInvalidMetadata{ + Message: fmt.Sprintf("bug: invalid metadata type received: %T", meta), + } } // Note: the Magic may be empty if OpenTofu isn't decrypting anything, make sure to account for that possibility. - if typedMeta.Magic != "" && typedMeta.Magic != magic { - return keyprovider.Output{}, nil, fmt.Errorf("corrupted data received, no or invalid magic string: %s", typedMeta.Magic) + var decryptionKey []byte + if typedMeta.Magic != "" { + decryptionKey = p.key + if typedMeta.Magic != magic { + return keyprovider.Output{}, nil, &keyprovider.ErrInvalidMetadata{ + Message: fmt.Sprintf("corrupted data received, no or invalid magic string: %s", typedMeta.Magic), + } + } } return keyprovider.Output{ EncryptionKey: p.key, - DecryptionKey: p.key, + DecryptionKey: decryptionKey, }, &Metadata{Magic: magic}, nil } diff --git a/internal/encryption/keyprovider/static/provider_test.go b/internal/encryption/keyprovider/static/provider_test.go index 1129282a29..324bd087ef 100644 --- a/internal/encryption/keyprovider/static/provider_test.go +++ b/internal/encryption/keyprovider/static/provider_test.go @@ -3,76 +3,115 @@ // Copyright (c) 2023 HashiCorp, Inc. // SPDX-License-Identifier: MPL-2.0 -package static_test +package static import ( "bytes" + "fmt" "testing" - "github.com/opentofu/opentofu/internal/encryption/keyprovider" + "github.com/opentofu/opentofu/internal/encryption/keyprovider/compliancetest" - "github.com/opentofu/opentofu/internal/encryption/keyprovider/static" + "github.com/opentofu/opentofu/internal/encryption/keyprovider" ) func TestKeyProvider(t *testing.T) { - // TODO: Rework to check the expected errors and not just expectSuccess - type testCase struct { - name string - key string - expectSuccess bool - expectedData keyprovider.Output - } - - testCases := []testCase{ - { - name: "Empty", - expectSuccess: true, - expectedData: keyprovider.Output{}, + compliancetest.ComplianceTest( + t, + compliancetest.TestConfiguration[*descriptor, *Config, *Metadata, *staticKeyProvider]{ + Descriptor: New().(*descriptor), + HCLParseTestCases: map[string]compliancetest.HCLParseTestCase[*Config, *staticKeyProvider]{ + "success": { + HCL: `key_provider "static" "foo" { + key = "48656c6c6f20776f726c6421" +}`, + ValidHCL: true, + ValidBuild: true, + Validate: func(config *Config, keyProvider *staticKeyProvider) error { + if config.Key != "48656c6c6f20776f726c6421" { + return fmt.Errorf("incorrect key returned") + } + if !bytes.Equal(keyProvider.key, []byte("Hello world!")) { + return fmt.Errorf("key provider contains invalid key") + } + return nil + }, + }, + "empty": { + HCL: `key_provider "static" "foo" {}`, + ValidHCL: false, + ValidBuild: false, + }, + "bad-hex": { + HCL: `key_provider "static" "foo" { + key = "G" +}`, + ValidHCL: true, + ValidBuild: false, + }, + "bad-argument": { + HCL: `key_provider "static" "foo" { + keys = "48656c6c6f20776f726c6421" # Note the incorrect key name +}`, + ValidHCL: false, + ValidBuild: false, + }, + }, + ConfigStructTestCases: map[string]compliancetest.ConfigStructTestCase[*Config, *staticKeyProvider]{ + "empty": { + Config: &Config{ + Key: "", + }, + ValidBuild: false, + Validate: nil, + }, + }, + MetadataStructTestCases: map[string]compliancetest.MetadataStructTestCase[*Config, *Metadata]{ + "empty": { + ValidConfig: &Config{ + Key: "48656c6c6f20776f726c6421", + }, + Meta: &Metadata{}, + IsPresent: false, + IsValid: false, + }, + "invalid": { + ValidConfig: &Config{ + Key: "48656c6c6f20776f726c6421", + }, + Meta: &Metadata{ + Magic: "Invalid", + }, + IsPresent: true, + IsValid: false, + }, + "valid": { + ValidConfig: &Config{ + Key: "48656c6c6f20776f726c6421", + }, + Meta: &Metadata{ + Magic: "Hello world!", + }, + IsPresent: true, + IsValid: true, + }, + }, + ProvideTestCase: compliancetest.ProvideTestCase[*Config, *Metadata]{ + ValidConfig: &Config{ + Key: "48656c6c6f20776f726c6421", + }, + ExpectedOutput: &keyprovider.Output{ + EncryptionKey: []byte("Hello world!"), // "48656c6c6f20776f726c6421" in hex is "Hello world!" + DecryptionKey: []byte("Hello world!"), + }, + ValidateKeys: nil, + ValidateMetadata: func(meta *Metadata) error { + if meta.Magic != "Hello world!" { + return fmt.Errorf("incorrect output magic: %s", meta.Magic) + } + return nil + }, + }, }, - { - name: "InvalidInput", - key: "G", - expectSuccess: false, - }, - { - name: "Success", - key: "48656c6c6f20776f726c6421", - expectSuccess: true, - expectedData: keyprovider.Output{EncryptionKey: []byte("Hello world!"), DecryptionKey: []byte("Hello world!")}, // "48656c6c6f20776f726c6421" in hex is "Hello world!" - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - descriptor := static.New() - c := descriptor.ConfigStruct().(*static.Config) - - // Set key if provided - if tc.key != "" { - c.Key = tc.key - } - - keyProvider, keyMeta, buildErr := c.Build() - if tc.expectSuccess { - if buildErr != nil { - t.Fatalf("unexpected error: %v", buildErr) - } - - output, _, err := keyProvider.Provide(keyMeta) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - if !bytes.Equal(output.EncryptionKey, tc.expectedData.EncryptionKey) { - t.Fatalf("unexpected encryption key in output: got %v, want %v", output.EncryptionKey, tc.expectedData.EncryptionKey) - } - if !bytes.Equal(output.DecryptionKey, tc.expectedData.DecryptionKey) { - t.Fatalf("unexpected decryption key in output: got %v, want %v", output.DecryptionKey, tc.expectedData.EncryptionKey) - } - } else { - if buildErr == nil { - t.Fatalf("expected an error but got none") - } - } - }) - } + ) } diff --git a/internal/encryption/keyprovider/validation.go b/internal/encryption/keyprovider/validation.go new file mode 100644 index 0000000000..62e8079467 --- /dev/null +++ b/internal/encryption/keyprovider/validation.go @@ -0,0 +1,13 @@ +// Copyright (c) The OpenTofu Authors +// SPDX-License-Identifier: MPL-2.0 +// Copyright (c) 2023 HashiCorp, Inc. +// SPDX-License-Identifier: MPL-2.0 + +package keyprovider + +import "regexp" + +// TODO is there a generalized way to regexp-check names? +var addrRe = regexp.MustCompile(`^key_provider\.([a-zA-Z_0-9-]+)\.([a-zA-Z_0-9-]+)$`) +var nameRe = regexp.MustCompile("^([a-zA-Z_0-9-]+)$") +var idRe = regexp.MustCompile("^([a-zA-Z_0-9-]+)$") diff --git a/internal/encryption/method/README.md b/internal/encryption/method/README.md index f67879f001..9033ef0e75 100644 --- a/internal/encryption/method/README.md +++ b/internal/encryption/method/README.md @@ -9,6 +9,10 @@ This folder contains the implementations for the encryption methods used in Open When you implement a method, take a look at the [aesgcm](aesgcm) method as a template. +### Testing your method (do this first!) + +Before you even go about writing a method, please set up the compliance tests. You can create a single test case that calls `compliancetest.ComplianceTest`. This test suite will run your key provider through all important compliance tests and will make sure that you are not missing anything during the implementation. + ### Implementing the descriptor The descriptor is very simple, you need to implement the [`Descriptor`](descriptor.go) interface in a type. (It does not have to be a struct.) However, make sure that the `ConfigStruct` always returns a struct with `hcl` tags on it. For more information on the `hcl` tags, see the [gohcl documentation](https://godocs.io/github.com/hashicorp/hcl/v2/gohcl). diff --git a/internal/encryption/method/addr.go b/internal/encryption/method/addr.go index 20ca8e1718..b755e1f008 100644 --- a/internal/encryption/method/addr.go +++ b/internal/encryption/method/addr.go @@ -15,6 +15,7 @@ import ( // TODO is there a generalized way to regexp-check names? var addrRe = regexp.MustCompile(`^method\.([a-zA-Z_0-9-]+)\.([a-zA-Z_0-9-]+)$`) var nameRe = regexp.MustCompile("^([a-zA-Z_0-9-]+)$") +var idRe = regexp.MustCompile("^([a-zA-Z_0-9-]+)$") // Addr is a type-alias for method address strings that identify a specific encryption method configuration. // The Addr is an opaque value. Do not perform string manipulation on it outside the functions supplied by the diff --git a/internal/encryption/method/aesgcm/aesgcm.go b/internal/encryption/method/aesgcm/aesgcm.go index 8617caa25e..3693d497cb 100644 --- a/internal/encryption/method/aesgcm/aesgcm.go +++ b/internal/encryption/method/aesgcm/aesgcm.go @@ -55,6 +55,9 @@ func (a aesgcm) Encrypt(data []byte) ([]byte, error) { // Decrypt decrypts an AES-GCM-encrypted data set. If the data set fails decryption, it returns an error. func (a aesgcm) Decrypt(data []byte) ([]byte, error) { + if len(a.decryptionKey) == 0 { + return nil, &method.ErrDecryptionKeyUnavailable{} + } result, err := handlePanic( func() ([]byte, error) { if len(data) == 0 { diff --git a/internal/encryption/method/aesgcm/compliance_test.go b/internal/encryption/method/aesgcm/compliance_test.go new file mode 100644 index 0000000000..7da0c36df4 --- /dev/null +++ b/internal/encryption/method/aesgcm/compliance_test.go @@ -0,0 +1,197 @@ +// Copyright (c) The OpenTofu Authors +// SPDX-License-Identifier: MPL-2.0 +// Copyright (c) 2023 HashiCorp, Inc. +// SPDX-License-Identifier: MPL-2.0 + +package aesgcm + +import ( + "bytes" + "fmt" + "testing" + + "github.com/opentofu/opentofu/internal/encryption/keyprovider" + "github.com/opentofu/opentofu/internal/encryption/method/compliancetest" +) + +func TestCompliance(t *testing.T) { + compliancetest.ComplianceTest(t, compliancetest.TestConfiguration[*descriptor, *Config, *aesgcm]{ + Descriptor: New().(*descriptor), + HCLParseTestCases: map[string]compliancetest.HCLParseTestCase[*descriptor, *Config, *aesgcm]{ + "empty": { + HCL: `method "aes_gcm" "foo" {}`, + ValidHCL: false, + ValidBuild: false, + Validate: nil, + }, + "empty_keys": { + HCL: `method "aes_gcm" "foo" { + keys = { + encryption_key = [] + decryption_key = [] + } + }`, + ValidHCL: true, + ValidBuild: false, + Validate: nil, + }, + "short-keys": { + HCL: `method "aes_gcm" "foo" { + keys = { + encryption_key = [1,2,3,4,5,6,7,8,9,10,11,12,13,14,15] + decryption_key = [1,2,3,4,5,6,7,8,9,10,11,12,13,14,15] + } + }`, + ValidHCL: true, + ValidBuild: false, + Validate: nil, + }, + "short-decryption-key": { + HCL: `method "aes_gcm" "foo" { + keys = { + encryption_key = [1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16] + decryption_key = [1,2,3,4,5,6,7,8,9,10,11,12,13,14,15] + } + }`, + ValidHCL: true, + ValidBuild: false, + Validate: nil, + }, + "short-encryption-key": { + HCL: `method "aes_gcm" "foo" { + keys = { + encryption_key = [1,2,3,4,5,6,7,8,9,10,11,12,13,14,15] + decryption_key = [1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16] + } + }`, + ValidHCL: true, + ValidBuild: false, + Validate: nil, + }, + "only-decryption-key": { + HCL: `method "aes_gcm" "foo" { + keys = { + encryption_key = [] + decryption_key = [1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16] + } + }`, + ValidHCL: true, + ValidBuild: false, + }, + "only-encryption-key": { + HCL: `method "aes_gcm" "foo" { + keys = { + encryption_key = [1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16] + decryption_key = [] + } + }`, + ValidHCL: true, + ValidBuild: true, + Validate: func(config *Config, method *aesgcm) error { + if len(config.Keys.DecryptionKey) > 0 { + return fmt.Errorf("decryption key found in config despite no decryption key being provided") + } + if len(method.decryptionKey) > 0 { + return fmt.Errorf("decryption key found in method despite no decryption key being provided") + } + if !bytes.Equal(config.Keys.EncryptionKey, []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}) { + return fmt.Errorf("incorrect encryption key found after HCL parsing in config") + } + if !bytes.Equal(method.encryptionKey, []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}) { + return fmt.Errorf("incorrect encryption key found after HCL parsing in config") + } + return nil + }, + }, + "encryption-decryption-key": { + HCL: `method "aes_gcm" "foo" { + keys = { + encryption_key = [1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16] + decryption_key = [1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16] + } + }`, + ValidHCL: true, + ValidBuild: true, + Validate: func(config *Config, method *aesgcm) error { + if !bytes.Equal(config.Keys.DecryptionKey, []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}) { + return fmt.Errorf("incorrect decryption key found after HCL parsing in config") + } + if !bytes.Equal(method.decryptionKey, []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}) { + return fmt.Errorf("incorrect decryption key found after HCL parsing in config") + } + + if !bytes.Equal(config.Keys.EncryptionKey, []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}) { + return fmt.Errorf("incorrect encryption key found after HCL parsing in config") + } + if !bytes.Equal(method.encryptionKey, []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}) { + return fmt.Errorf("incorrect encryption key found after HCL parsing in config") + } + return nil + }, + }, + "no-aad": { + HCL: `method "aes_gcm" "foo" { + keys = { + encryption_key = [1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16] + decryption_key = [1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16] + } + }`, + ValidHCL: true, + ValidBuild: true, + Validate: func(config *Config, method *aesgcm) error { + if len(config.AAD) != 0 { + return fmt.Errorf("invalid AAD in config after HCL parsing") + } + if len(method.aad) != 0 { + return fmt.Errorf("invalid AAD in method after Build()") + } + return nil + }, + }, + "aad": { + HCL: `method "aes_gcm" "foo" { + keys = { + encryption_key = [1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16] + decryption_key = [1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16] + } + aad = [1,2,3,4] + }`, + ValidHCL: true, + ValidBuild: true, + Validate: func(config *Config, method *aesgcm) error { + if !bytes.Equal(config.AAD, []byte{1, 2, 3, 4}) { + return fmt.Errorf("invalid AAD in config after HCL parsing") + } + if !bytes.Equal(method.aad, []byte{1, 2, 3, 4}) { + return fmt.Errorf("invalid AAD in method after Build()") + } + return nil + }, + }, + }, + ConfigStructTestCases: map[string]compliancetest.ConfigStructTestCase[*Config, *aesgcm]{ + "empty": { + Config: &Config{ + Keys: keyprovider.Output{}, + AAD: nil, + }, + ValidBuild: false, + Validate: nil, + }, + }, + EncryptDecryptTestCase: compliancetest.EncryptDecryptTestCase[*Config, *aesgcm]{ + ValidEncryptOnlyConfig: &Config{ + Keys: keyprovider.Output{ + EncryptionKey: []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, + DecryptionKey: nil, + }, + }, + ValidFullConfig: &Config{ + Keys: keyprovider.Output{ + EncryptionKey: []byte{17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32}, + DecryptionKey: []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, + }, + }, + }, + }) +} diff --git a/internal/encryption/method/aesgcm/config.go b/internal/encryption/method/aesgcm/config.go index 40efc187dc..e6d56cf527 100644 --- a/internal/encryption/method/aesgcm/config.go +++ b/internal/encryption/method/aesgcm/config.go @@ -36,11 +36,6 @@ func (c *Config) Build() (method.Method, error) { encryptionKey := c.Keys.EncryptionKey decryptionKey := c.Keys.DecryptionKey - if len(decryptionKey) == 0 { - // Use encryption key as decryption key if missing - decryptionKey = encryptionKey - } - if !validKeyLengths.Has(len(encryptionKey)) { return nil, &method.ErrInvalidConfiguration{ Cause: fmt.Errorf( @@ -51,13 +46,15 @@ func (c *Config) Build() (method.Method, error) { } } - if !validKeyLengths.Has(len(decryptionKey)) { - return nil, &method.ErrInvalidConfiguration{ - Cause: fmt.Errorf( - "AES-GCM requires the key length to be one of: %s, received %d bytes in the decryption key", - validKeyLengths.String(), - len(decryptionKey), - ), + if len(decryptionKey) > 0 { + if !validKeyLengths.Has(len(decryptionKey)) { + return nil, &method.ErrInvalidConfiguration{ + Cause: fmt.Errorf( + "AES-GCM requires the key length to be one of: %s, received %d bytes in the decryption key", + validKeyLengths.String(), + len(decryptionKey), + ), + } } } diff --git a/internal/encryption/method/aesgcm/config_test.go b/internal/encryption/method/aesgcm/config_test.go index a23b803053..b5827fe384 100644 --- a/internal/encryption/method/aesgcm/config_test.go +++ b/internal/encryption/method/aesgcm/config_test.go @@ -84,19 +84,6 @@ func TestConfig_Build(t *testing.T) { }, errorType: &method.ErrInvalidConfiguration{}, }, - { - name: "decryption-key-fallback", - config: &Config{ - Keys: keyprovider.Output{ - EncryptionKey: []byte("bohwu9zoo7Zooe16"), - }, - }, - errorType: nil, - expected: aesgcm{ - encryptionKey: []byte("bohwu9zoo7Zooe16"), - decryptionKey: []byte("bohwu9zoo7Zooe16"), - }, - }, { name: "aad", config: &Config{ diff --git a/internal/encryption/method/compliancetest/compliance.go b/internal/encryption/method/compliancetest/compliance.go new file mode 100644 index 0000000000..f05b5ce1b7 --- /dev/null +++ b/internal/encryption/method/compliancetest/compliance.go @@ -0,0 +1,321 @@ +// Copyright (c) The OpenTofu Authors +// SPDX-License-Identifier: MPL-2.0 +// Copyright (c) 2023 HashiCorp, Inc. +// SPDX-License-Identifier: MPL-2.0 + +package compliancetest + +import ( + "bytes" + "errors" + "reflect" + "testing" + + "github.com/hashicorp/hcl/v2/gohcl" + "github.com/opentofu/opentofu/internal/encryption/compliancetest" + "github.com/opentofu/opentofu/internal/encryption/config" + "github.com/opentofu/opentofu/internal/encryption/method" +) + +// ComplianceTest tests the functionality of a method to make sure it conforms to the expectations of the method +// interface. +func ComplianceTest[TDescriptor method.Descriptor, TConfig method.Config, TMethod method.Method]( + t *testing.T, + testConfig TestConfiguration[TDescriptor, TConfig, TMethod], +) { + testConfig.execute(t) +} + +type TestConfiguration[TDescriptor method.Descriptor, TConfig method.Config, TMethod method.Method] struct { + Descriptor TDescriptor + // HCLParseTestCases contains the test cases of parsing HCL configuration and then validating it using the Build() + // function. + HCLParseTestCases map[string]HCLParseTestCase[TDescriptor, TConfig, TMethod] + + // ConfigStructT validates that a certain config results or does not result in a valid Build() call. + ConfigStructTestCases map[string]ConfigStructTestCase[TConfig, TMethod] + + // ProvideTestCase exercises the entire chain and generates two keys. + EncryptDecryptTestCase EncryptDecryptTestCase[TConfig, TMethod] +} + +func (cfg *TestConfiguration[TDescriptor, TConfig, TMethod]) execute(t *testing.T) { + t.Run("id", func(t *testing.T) { + cfg.testID(t) + }) + t.Run("hcl", func(t *testing.T) { + cfg.testHCL(t) + }) + t.Run("config-struct", func(t *testing.T) { + cfg.testConfigStruct(t) + }) + t.Run("encrypt-decrypt", func(t *testing.T) { + cfg.EncryptDecryptTestCase.execute(t) + }) +} + +func (cfg *TestConfiguration[TDescriptor, TConfig, TMethod]) testID(t *testing.T) { + id := cfg.Descriptor.ID() + if err := id.Validate(); err != nil { + compliancetest.Fail(t, "Invalid ID returned from method descriptor: %s (%v)", id, err) + } else { + compliancetest.Log(t, "The ID provided by the method descriptor is valid: %s", id) + } +} + +func (cfg *TestConfiguration[TDescriptor, TConfig, TMethod]) testHCL(t *testing.T) { + if cfg.HCLParseTestCases == nil { + compliancetest.Fail(t, "Please provide a map to HCLParseTestCases.") + } + hasInvalidHCL := false + hasValidHCLInvalidBuild := false + hasValidBuild := false + for name, tc := range cfg.HCLParseTestCases { + tc := tc + if !tc.ValidHCL { + hasInvalidHCL = true + } else { + if tc.ValidBuild { + hasValidBuild = true + } else { + hasValidHCLInvalidBuild = true + } + } + t.Run(name, func(t *testing.T) { + tc.execute(t, cfg.Descriptor) + }) + } + t.Run("completeness", func(t *testing.T) { + if !hasInvalidHCL { + compliancetest.Fail(t, "Please provide at least one test case with an invalid HCL.") + } + if !hasValidHCLInvalidBuild { + compliancetest.Fail(t, "Please provide at least one test case with a valid HCL that fails on Build()") + } + if !hasValidBuild { + compliancetest.Fail( + t, + "Please provide at least one test case with a valid HCL that succeeds on Build()", + ) + } + }) +} + +func (cfg *TestConfiguration[TDescriptor, TConfig, TMethod]) testConfigStruct(t *testing.T) { + compliancetest.ConfigStruct[TConfig](t, cfg.Descriptor.ConfigStruct()) + + if cfg.ConfigStructTestCases == nil { + compliancetest.Fail(t, "Please provide a map to ConfigStructTestCases.") + } + + for name, tc := range cfg.ConfigStructTestCases { + tc := tc + t.Run(name, func(t *testing.T) { + tc.execute(t) + }) + } +} + +// HCLParseTestCase contains a test case that parses HCL into a configuration. +type HCLParseTestCase[TDescriptor method.Descriptor, TConfig method.Config, TMethod method.Method] struct { + // HCL contains the code that should be parsed into the configuration structure. + HCL string + // ValidHCL indicates that the HCL block should be parsable into the configuration structure, but not necessarily + // result in a valid Build() call. + ValidHCL bool + // ValidBuild indicates that calling the Build() function should not result in an error. + ValidBuild bool + // Validate is an extra optional validation function that can check if the configuration contains the correct + // values parsed from HCL. If ValidBuild is true, the method will be passed as well. + Validate func(config TConfig, method TMethod) error +} + +func (h *HCLParseTestCase[TDescriptor, TConfig, TMethod]) execute(t *testing.T, descriptor TDescriptor) { + parseError := false + parsedConfig, diags := config.LoadConfigFromString("config.hcl", h.HCL) + if h.ValidHCL { + if diags.HasErrors() { + compliancetest.Fail(t, "Unexpected HCL error (%v).", diags) + } else { + compliancetest.Log(t, "HCL successfully parsed.") + } + } else { + if diags.HasErrors() { + parseError = true + } + } + + configStruct := descriptor.ConfigStruct() + diags = gohcl.DecodeBody( + parsedConfig.MethodConfigs[0].Body, + nil, + configStruct, + ) + var m TMethod + if h.ValidHCL { + if diags.HasErrors() { + compliancetest.Fail(t, "Failed to parse empty HCL block into config struct (%v).", diags) + } else { + compliancetest.Log(t, "HCL successfully loaded into config struct.") + } + + m = buildConfigAndValidate[TMethod](t, configStruct, h.ValidBuild) + } else { + if !parseError && !diags.HasErrors() { + compliancetest.Fail(t, "Expected error during HCL parsing, but no error was returned.") + } else { + compliancetest.Log(t, "HCL loading errored correctly (%v).", diags) + } + } + + if h.Validate != nil { + if err := h.Validate(configStruct.(TConfig), m); err != nil { + compliancetest.Fail(t, "Error during validation and configuration (%v).", err) + } else { + compliancetest.Log(t, "Successfully validated parsed HCL config and applied modifications.") + } + } else { + compliancetest.Log(t, "No ValidateAndConfigure provided, skipping HCL parse validation.") + } +} + +// ConfigStructTestCase validates that the config struct is behaving correctly when Build() is called. +type ConfigStructTestCase[TConfig method.Config, TMethod method.Method] struct { + Config TConfig + ValidBuild bool + Validate func(method TMethod) error +} + +func (m ConfigStructTestCase[TConfig, TMethod]) execute(t *testing.T) { + newMethod := buildConfigAndValidate[TMethod, TConfig](t, m.Config, m.ValidBuild) + if m.Validate != nil { + if err := m.Validate(newMethod); err != nil { + compliancetest.Fail(t, "method validation failed (%v)", err) + } + } +} + +// EncryptDecryptTestCase handles a full encryption-decryption cycle. +type EncryptDecryptTestCase[TConfig method.Config, TMethod method.Method] struct { + // ValidEncryptOnlyConfig is a configuration that has no decryption key and can only be used for encryption. The + // key must match ValidFullConfig. + ValidEncryptOnlyConfig TConfig + // ValidFullConfig is a configuration that contains both an encryption and decryption key. + ValidFullConfig TConfig +} + +func (m EncryptDecryptTestCase[TConfig, TMethod]) execute(t *testing.T) { + if reflect.ValueOf(m.ValidEncryptOnlyConfig).IsNil() { + compliancetest.Fail(t, "Please provide a ValidEncryptOnlyConfig to EncryptDecryptTestCase.") + } + if reflect.ValueOf(m.ValidFullConfig).IsNil() { + compliancetest.Fail(t, "Please provide a ValidFullConfig to EncryptDecryptTestCase.") + } + + encryptMethod := buildConfigAndValidate[TMethod, TConfig](t, m.ValidEncryptOnlyConfig, true) + decryptMethod := buildConfigAndValidate[TMethod, TConfig](t, m.ValidFullConfig, true) + + plainData := []byte("Hello world!") + encryptedData, err := encryptMethod.Encrypt(plainData) + if err != nil { + compliancetest.Fail(t, "Unexpected error after Encrypt() on the encrypt-only method (%v).", err) + } + + _, err = encryptMethod.Decrypt(encryptedData) + if err == nil { + compliancetest.Fail(t, "Decrypt() did not fail without a decryption key.") + } else { + compliancetest.Log(t, "Decrypt() returned an error with a decryption key.") + } + var noDecryptionKeyError *method.ErrDecryptionKeyUnavailable + if !errors.As(err, &noDecryptionKeyError) { + compliancetest.Fail(t, "Decrypt() returned a %T instead of a %T without a decryption key. Please use the correct typed errors.", err, noDecryptionKeyError) + } else { + compliancetest.Log(t, "Decrypt() returned the correct error type of %T without a decryption key.", noDecryptionKeyError) + } + + _, err = decryptMethod.Decrypt([]byte{}) + if err == nil { + compliancetest.Fail(t, "Decrypt() must return an error when decrypting empty data, no error returned.") + } else { + compliancetest.Log(t, "Decrypt() correctly returned an error when decrypting empty data.") + } + var typedDecryptError *method.ErrDecryptionFailed + if !errors.As(err, &typedDecryptError) { + compliancetest.Fail(t, "Decrypt() returned a %T instead of a %T when decrypting empty data. Please use the correct typed errors.", err, typedDecryptError) + } else { + compliancetest.Log(t, "Decrypt() returned the correct error type of %T when decrypting empty data.", typedDecryptError) + } + typedDecryptError = nil + + _, err = decryptMethod.Decrypt(plainData) + if err == nil { + compliancetest.Fail(t, "Decrypt() must return an error when decrypting unencrypted data, no error returned.") + } else { + compliancetest.Log(t, "Decrypt() correctly returned an error when decrypting unencrypted data.") + } + if !errors.As(err, &typedDecryptError) { + compliancetest.Fail(t, "Decrypt() returned a %T instead of a %T when decrypting unencrypted data. Please use the correct typed errors.", err, typedDecryptError) + } else { + compliancetest.Log(t, "Decrypt() returned the correct error type of %T when decrypting unencrypted data.", typedDecryptError) + } + + decryptedData, err := decryptMethod.Decrypt(encryptedData) + if err != nil { + compliancetest.Fail(t, "Decrypt() failed to decrypt previously-encrypted data (%v).", err) + } else { + compliancetest.Log(t, "Decrypt() succeeded.") + } + + if !bytes.Equal(decryptedData, plainData) { + compliancetest.Fail(t, "Decrypt() returned incorrect plain text data:\n%v\nexpected:\n%v", decryptedData, plainData) + } else { + compliancetest.Log(t, "Decrypt() returned the correct plain text data.") + } +} + +func buildConfigAndValidate[TMethod method.Method, TConfig method.Config]( + t *testing.T, + configStruct TConfig, + validBuild bool, +) TMethod { + if reflect.ValueOf(configStruct).IsNil() { + compliancetest.Fail(t, "Nil struct passed!") + } + + var typedMethod TMethod + var ok bool + kp, err := configStruct.Build() + if validBuild { + if err != nil { + compliancetest.Fail(t, "Build() returned an unexpected error: %v.", err) + } else { + compliancetest.Log(t, "Build() did not return an error.") + } + typedMethod, ok = kp.(TMethod) + if !ok { + compliancetest.Fail(t, "Build() returned an invalid method type of %T, expected %T", kp, typedMethod) + } else { + compliancetest.Log(t, "Build() returned the correct method type of %T.", typedMethod) + } + } else { + if err == nil { + compliancetest.Fail(t, "Build() did not return an error.") + } else { + compliancetest.Log(t, "Build() correctly returned an error: %v", err) + } + + var typedError *method.ErrInvalidConfiguration + if !errors.As(err, &typedError) { + compliancetest.Fail( + t, + "Build() did not return the correct error type, got %T but expected %T", + err, + typedError, + ) + } else { + compliancetest.Log(t, "Build() returned the correct error type of %T", typedError) + } + } + return typedMethod +} diff --git a/internal/encryption/method/errors.go b/internal/encryption/method/errors.go index aa2c8adb61..565f1a437b 100644 --- a/internal/encryption/method/errors.go +++ b/internal/encryption/method/errors.go @@ -52,6 +52,14 @@ func (e ErrDecryptionFailed) Unwrap() error { return e.Cause } +// ErrDecryptionKeyUnavailable indicates that no decryption key is available. +type ErrDecryptionKeyUnavailable struct { +} + +func (e ErrDecryptionKeyUnavailable) Error() string { + return "no decryption key available" +} + // ErrInvalidConfiguration indicates that the method configuration is incorrect. type ErrInvalidConfiguration struct { Cause error diff --git a/internal/encryption/method/id.go b/internal/encryption/method/id.go index 08626dfe65..b79669826d 100644 --- a/internal/encryption/method/id.go +++ b/internal/encryption/method/id.go @@ -5,11 +5,20 @@ package method +import ( + "fmt" +) + // ID is a type alias to make passing the wrong ID into a method ID harder. type ID string // Validate validates the key provider ID for correctness. func (i ID) Validate() error { - // TODO implement format checking + if i == "" { + return fmt.Errorf("empty key provider ID (key provider IDs must match %s)", idRe.String()) + } + if !idRe.MatchString(string(i)) { + return fmt.Errorf("invalid key provider ID: %s (must match %s)", i, idRe.String()) + } return nil } diff --git a/internal/encryption/methods.go b/internal/encryption/methods.go index 6bac31ee07..a325a251bd 100644 --- a/internal/encryption/methods.go +++ b/internal/encryption/methods.go @@ -50,7 +50,7 @@ func (e *targetBuilder) setupMethod(cfg config.MethodConfig) hcl.Diagnostics { } // Lookup the definition of the encryption method from the registry - encryptionMethod, err := e.reg.GetMethod(method.ID(cfg.Type)) + encryptionMethod, err := e.reg.GetMethodDescriptor(method.ID(cfg.Type)) if err != nil { // Handle if the method was not found diff --git a/internal/encryption/registry/compliancetest/compliance.go b/internal/encryption/registry/compliancetest/compliance.go new file mode 100644 index 0000000000..7f790dc37a --- /dev/null +++ b/internal/encryption/registry/compliancetest/compliance.go @@ -0,0 +1,29 @@ +// Copyright (c) The OpenTofu Authors +// SPDX-License-Identifier: MPL-2.0 +// Copyright (c) 2023 HashiCorp, Inc. +// SPDX-License-Identifier: MPL-2.0 + +package compliancetest + +import ( + "testing" + + "github.com/opentofu/opentofu/internal/encryption/registry" +) + +func ComplianceTest(t *testing.T, factory func() registry.Registry) { + t.Run("returns-registry", func(t *testing.T) { + reg := factory() + if reg == nil { + t.Fatalf("Calling the factory method did not return a valid registry.") + } + }) + + t.Run("key_provider", func(t *testing.T) { + complianceTestKeyProviders(t, factory) + }) + + t.Run("method", func(t *testing.T) { + complianceTestMethods(t, factory) + }) +} diff --git a/internal/encryption/registry/compliancetest/compliance_key_provider.go b/internal/encryption/registry/compliancetest/compliance_key_provider.go new file mode 100644 index 0000000000..0ca31a8904 --- /dev/null +++ b/internal/encryption/registry/compliancetest/compliance_key_provider.go @@ -0,0 +1,137 @@ +// Copyright (c) The OpenTofu Authors +// SPDX-License-Identifier: MPL-2.0 +// Copyright (c) 2023 HashiCorp, Inc. +// SPDX-License-Identifier: MPL-2.0 + +package compliancetest + +import ( + "errors" + "testing" + + "github.com/opentofu/opentofu/internal/encryption/keyprovider" + "github.com/opentofu/opentofu/internal/encryption/registry" +) + +func complianceTestKeyProviders(t *testing.T, factory func() registry.Registry) { + t.Run("registration-and-return", func(t *testing.T) { + complianceTestKeyProviderRegistrationAndReturn(t, factory) + }) + t.Run("register-invalid-id", func(t *testing.T) { + complianceTestKeyProviderInvalidID(t, factory) + }) + t.Run("duplicate-registration", func(t *testing.T) { + complianceTestKeyProviderDuplicateRegistration(t, factory) + }) +} + +func complianceTestKeyProviderRegistrationAndReturn(t *testing.T, factory func() registry.Registry) { + reg := factory() + testKeyProvider := &testKeyProviderDescriptor{ + "test", + } + if err := reg.RegisterKeyProvider(testKeyProvider); err != nil { + t.Fatalf("Failed to register test key provider with ID %s (%v)", testKeyProvider.id, err) + } + returnedKeyProvider, err := reg.GetKeyProviderDescriptor(testKeyProvider.id) + if err != nil { + t.Fatalf("The previously registered key provider with the ID %s couldn't be fetched from the registry (%v).", testKeyProvider.id, err) + } + returnedTypedKeyProvider, ok := returnedKeyProvider.(*testKeyProviderDescriptor) + if !ok { + t.Fatalf("The returned key provider was not of the expected type of %T, but instead it was %T.", testKeyProvider, returnedKeyProvider) + } + if returnedTypedKeyProvider.id != testKeyProvider.id { + t.Fatalf("The returned key provider contained the wrong ID %s instead of %s", returnedTypedKeyProvider.id, testKeyProvider.id) + } + + _, err = reg.GetKeyProviderDescriptor("nonexistent") + if err == nil { + t.Fatalf("Requesting a non-existent key provider from GetKeyProviderDescriptor did not return an error.") + } + var typedErr *registry.KeyProviderNotFoundError + if !errors.As(err, &typedErr) { + t.Fatalf( + "Requesting a non-existent key provider from GetKeyProviderDescriptor returned an incorrect error type of %T. This function should always return a *registry.KeyProviderNotFoundError if the key provider was not found.", + err, + ) + } +} + +func complianceTestKeyProviderInvalidID(t *testing.T, factory func() registry.Registry) { + reg := factory() + testKeyProvider := &testKeyProviderDescriptor{ + "Hello world!", + } + err := reg.RegisterKeyProvider(testKeyProvider) + if err == nil { + t.Fatalf("Registering a key provider with the invalid ID of %s did not result in an error.", testKeyProvider.id) + } + var typedErr *registry.InvalidKeyProviderError + if !errors.As(err, &typedErr) { + t.Fatalf( + "Registering a key provider with an invalid ID of %s resulted in an error of type %T instead of %T. Please make sure to use the correct typed errors.", + testKeyProvider.id, + err, + typedErr, + ) + } +} + +func complianceTestKeyProviderDuplicateRegistration(t *testing.T, factory func() registry.Registry) { + reg := factory() + testKeyProvider := &testKeyProviderDescriptor{ + "test", + } + testKeyProvider2 := &testKeyProviderDescriptor{ + "test", + } + if err := reg.RegisterKeyProvider(testKeyProvider); err != nil { + t.Fatalf("Failed to register test key provider with ID %s (%v)", testKeyProvider.id, err) + } + err := reg.RegisterKeyProvider(testKeyProvider) + if err == nil { + t.Fatalf("Re-registering the same key provider again did not result in an error.") + } + var typedErr *registry.KeyProviderAlreadyRegisteredError + if !errors.As(err, &typedErr) { + t.Fatalf( + "Re-registering the same key provider twice resulted in an error of the type %T instead of %T. Please make sure to use the correct typed errors.", + err, + typedErr, + ) + } + + err = reg.RegisterKeyProvider(testKeyProvider2) + if err == nil { + t.Fatalf("Re-registering the a provider with a duplicate ID did not result in an error.") + } + if !errors.As(err, &typedErr) { + t.Fatalf( + "Re-registering the a key provider with a duplicate ID resulted in an error of the type %T instead of %T. Please make sure to use the correct typed errors.", + err, + typedErr, + ) + } +} + +type testKeyProviderDescriptor struct { + id keyprovider.ID +} + +func (t testKeyProviderDescriptor) ID() keyprovider.ID { + return t.id +} + +func (t testKeyProviderDescriptor) ConfigStruct() keyprovider.Config { + return &testKeyProviderConfigStruct{} +} + +type testKeyProviderConfigStruct struct { +} + +func (t testKeyProviderConfigStruct) Build() (keyprovider.KeyProvider, keyprovider.KeyMeta, error) { + return nil, nil, keyprovider.ErrInvalidConfiguration{ + Message: "The Build() function is not implemented on the testKeyProviderConfigStruct", + } +} diff --git a/internal/encryption/registry/compliancetest/compliance_method.go b/internal/encryption/registry/compliancetest/compliance_method.go new file mode 100644 index 0000000000..c6650c5032 --- /dev/null +++ b/internal/encryption/registry/compliancetest/compliance_method.go @@ -0,0 +1,138 @@ +// Copyright (c) The OpenTofu Authors +// SPDX-License-Identifier: MPL-2.0 +// Copyright (c) 2023 HashiCorp, Inc. +// SPDX-License-Identifier: MPL-2.0 + +package compliancetest + +import ( + "errors" + "fmt" + "testing" + + "github.com/opentofu/opentofu/internal/encryption/method" + "github.com/opentofu/opentofu/internal/encryption/registry" +) + +func complianceTestMethods(t *testing.T, factory func() registry.Registry) { + t.Run("registration-and-return", func(t *testing.T) { + complianceTestMethodRegistrationAndReturn(t, factory) + }) + t.Run("register-invalid-id", func(t *testing.T) { + complianceTestMethodInvalidID(t, factory) + }) + t.Run("duplicate-registration", func(t *testing.T) { + complianceTestMethodDuplicateRegistration(t, factory) + }) +} + +func complianceTestMethodRegistrationAndReturn(t *testing.T, factory func() registry.Registry) { + reg := factory() + testMethod := &testMethodDescriptor{ + "test", + } + if err := reg.RegisterMethod(testMethod); err != nil { + t.Fatalf("Failed to register test method with ID %s (%v)", testMethod.id, err) + } + returnedMethod, err := reg.GetMethodDescriptor(testMethod.id) + if err != nil { + t.Fatalf("The previously registered method with the ID %s couldn't be fetched from the registry (%v).", testMethod.id, err) + } + returnedTypedMethod, ok := returnedMethod.(*testMethodDescriptor) + if !ok { + t.Fatalf("The returned method was not of the expected type of %T, but instead it was %T.", testMethod, returnedMethod) + } + if returnedTypedMethod.id != testMethod.id { + t.Fatalf("The returned method contained the wrong ID %s instead of %s", returnedTypedMethod.id, testMethod.id) + } + + _, err = reg.GetMethodDescriptor("nonexistent") + if err == nil { + t.Fatalf("Requesting a non-existent method from GetMethodDescriptor did not return an error.") + } + var typedErr *registry.MethodNotFoundError + if !errors.As(err, &typedErr) { + t.Fatalf( + "Requesting a non-existent method from GetMethodDescriptor returned an incorrect error type of %T. This function should always return a *registry.MethodNotFoundError if the method was not found.", + err, + ) + } +} + +func complianceTestMethodInvalidID(t *testing.T, factory func() registry.Registry) { + reg := factory() + testMethod := &testMethodDescriptor{ + "Hello world!", + } + err := reg.RegisterMethod(testMethod) + if err == nil { + t.Fatalf("Registering a method with the invalid ID of %s did not result in an error.", testMethod.id) + } + var typedErr *registry.InvalidMethodError + if !errors.As(err, &typedErr) { + t.Fatalf( + "Registering a method with an invalid ID of %s resulted in an error of type %T instead of %T. Please make sure to use the correct typed errors.", + testMethod.id, + err, + typedErr, + ) + } +} + +func complianceTestMethodDuplicateRegistration(t *testing.T, factory func() registry.Registry) { + reg := factory() + testMethod := &testMethodDescriptor{ + "test", + } + testMethod2 := &testMethodDescriptor{ + "test", + } + if err := reg.RegisterMethod(testMethod); err != nil { + t.Fatalf("Failed to register test method with ID %s (%v)", testMethod.id, err) + } + err := reg.RegisterMethod(testMethod) + if err == nil { + t.Fatalf("Re-registering the same method again did not result in an error.") + } + var typedErr *registry.MethodAlreadyRegisteredError + if !errors.As(err, &typedErr) { + t.Fatalf( + "Re-registering the same method twice resulted in an error of the type %T instead of %T. Please make sure to use the correct typed errors.", + err, + typedErr, + ) + } + + err = reg.RegisterMethod(testMethod2) + if err == nil { + t.Fatalf("Re-registering the a provider with a duplicate ID did not result in an error.") + } + if !errors.As(err, &typedErr) { + t.Fatalf( + "Re-registering the a method with a duplicate ID resulted in an error of the type %T instead of %T. Please make sure to use the correct typed errors.", + err, + typedErr, + ) + } +} + +type testMethodDescriptor struct { + id method.ID +} + +func (t testMethodDescriptor) ID() method.ID { + return t.id +} + +func (t testMethodDescriptor) ConfigStruct() method.Config { + return &testMethodConfig{} +} + +type testMethodConfig struct { +} + +func (t testMethodConfig) Build() (method.Method, error) { + return nil, method.ErrInvalidConfiguration{ + Cause: fmt.Errorf("build not implemented for test method"), + } +} diff --git a/internal/encryption/registry/lockingencryptionregistry/new.go b/internal/encryption/registry/lockingencryptionregistry/registry.go similarity index 95% rename from internal/encryption/registry/lockingencryptionregistry/new.go rename to internal/encryption/registry/lockingencryptionregistry/registry.go index 264d3146a7..dd6c847eaf 100644 --- a/internal/encryption/registry/lockingencryptionregistry/new.go +++ b/internal/encryption/registry/lockingencryptionregistry/registry.go @@ -68,7 +68,7 @@ func (l *lockingRegistry) GetKeyProviderDescriptor(id keyprovider.ID) (keyprovid return provider, nil } -func (l *lockingRegistry) GetMethod(id method.ID) (method.Descriptor, error) { +func (l *lockingRegistry) GetMethodDescriptor(id method.ID) (method.Descriptor, error) { l.lock.RLock() defer l.lock.RUnlock() foundMethod, ok := l.methods[id] diff --git a/internal/encryption/registry/lockingencryptionregistry/registry_test.go b/internal/encryption/registry/lockingencryptionregistry/registry_test.go new file mode 100644 index 0000000000..dbfec80f7b --- /dev/null +++ b/internal/encryption/registry/lockingencryptionregistry/registry_test.go @@ -0,0 +1,17 @@ +// Copyright (c) The OpenTofu Authors +// SPDX-License-Identifier: MPL-2.0 +// Copyright (c) 2023 HashiCorp, Inc. +// SPDX-License-Identifier: MPL-2.0 + +package lockingencryptionregistry_test + +import ( + "testing" + + "github.com/opentofu/opentofu/internal/encryption/registry/compliancetest" + "github.com/opentofu/opentofu/internal/encryption/registry/lockingencryptionregistry" +) + +func TestCompliance(t *testing.T) { + compliancetest.ComplianceTest(t, lockingencryptionregistry.New) +} diff --git a/internal/encryption/registry/registry.go b/internal/encryption/registry/registry.go index b95a2e3d5c..3aa5950e66 100644 --- a/internal/encryption/registry/registry.go +++ b/internal/encryption/registry/registry.go @@ -21,11 +21,11 @@ type Registry interface { // already registered. RegisterMethod(method method.Descriptor) error - // GetKeyProvider returns the key provider with the specified ID. If the key provider is not registered, + // GetKeyProviderDescriptor returns the key provider with the specified ID. If the key provider is not registered, // it will return a *KeyProviderNotFoundError error. GetKeyProviderDescriptor(id keyprovider.ID) (keyprovider.Descriptor, error) - // GetMethod returns the method with the specified ID. + // GetMethodDescriptor returns the method with the specified ID. // If the method is not registered, it will return a *MethodNotFoundError. - GetMethod(id method.ID) (method.Descriptor, error) + GetMethodDescriptor(id method.ID) (method.Descriptor, error) }