Initial implementation of aws_kms encryption.key_provider (#1349)

Signed-off-by: Christian Mesh <christianmesh1@gmail.com>
Signed-off-by: James Humphries <james@james-humphries.co.uk>
Co-authored-by: James Humphries <james@james-humphries.co.uk>
This commit is contained in:
Christian Mesh 2024-03-13 13:19:20 -04:00 committed by GitHub
parent 0a747533bf
commit 07a9185767
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
27 changed files with 3004 additions and 95 deletions

1
go.mod
View File

@ -20,6 +20,7 @@ require (
github.com/aws/aws-sdk-go-v2 v1.23.2
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.14.6
github.com/aws/aws-sdk-go-v2/service/dynamodb v1.25.5
github.com/aws/aws-sdk-go-v2/service/kms v1.26.5
github.com/aws/aws-sdk-go-v2/service/s3 v1.46.0
github.com/bgentry/speakeasy v0.1.0
github.com/bmatcuk/doublestar/v4 v4.6.0

2
go.sum
View File

@ -341,6 +341,8 @@ github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.10.5 h1:F+XafeiK7
github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.10.5/go.mod h1:NlZuvlkyu6l/F3+qIBsGGtYLL2Z71tCf5NFoNAaG1NY=
github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.16.5 h1:ow5dalHqYM8IbzXFCL86gQY9UJUtZsLyBHUd6OKep9M=
github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.16.5/go.mod h1:AcvGHLN2pTXdx1oVFSzcclBvfY2VbBg0AfOE/XjA7oo=
github.com/aws/aws-sdk-go-v2/service/kms v1.26.5 h1:MRNoQVbEtjzhYFeKVMifHae4K5q4FuK9B7tTDskIF/g=
github.com/aws/aws-sdk-go-v2/service/kms v1.26.5/go.mod h1:gfe6e+rOxaiz/gr5Myk83ruBD6F9WvM7TZbLjcTNsDM=
github.com/aws/aws-sdk-go-v2/service/s3 v1.46.0 h1:RaXPp86CLxTKDwCwSTmTW7FvTfaLPXhN48mPtQ881bA=
github.com/aws/aws-sdk-go-v2/service/s3 v1.46.0/go.mod h1:x7gN1BRfTWXdPr/cFGM/iz+c87gRtJ+JMYinObt/0LI=
github.com/aws/aws-sdk-go-v2/service/sqs v1.28.4 h1:Hy1cUZGuZRHe3HPxw7nfA9BFUqdWbyI0JLLiqENgucc=

View File

@ -57,6 +57,7 @@ type Backend struct {
// ConfigSchema returns a description of the expected configuration
// structure for the receiving backend.
// This structure is mirrored by the encryption aws_kms key provider and should be kept in sync.
func (b *Backend) ConfigSchema() *configschema.Block {
return &configschema.Block{
Attributes: map[string]*configschema.Attribute{

View File

@ -6,6 +6,7 @@
package encryption
import (
"github.com/opentofu/opentofu/internal/encryption/keyprovider/aws_kms"
"github.com/opentofu/opentofu/internal/encryption/keyprovider/pbkdf2"
"github.com/opentofu/opentofu/internal/encryption/method/aesgcm"
"github.com/opentofu/opentofu/internal/encryption/registry/lockingencryptionregistry"
@ -20,4 +21,7 @@ func init() {
if err := DefaultRegistry.RegisterMethod(aesgcm.New()); err != nil {
panic(err)
}
if err := DefaultRegistry.RegisterKeyProvider(aws_kms.New()); err != nil {
panic(err)
}
}

View File

@ -13,10 +13,9 @@ import (
"github.com/opentofu/opentofu/internal/encryption/config"
"github.com/hashicorp/hcl/v2"
"github.com/hashicorp/hcl/v2/gohcl"
"github.com/opentofu/opentofu/internal/encryption/keyprovider"
"github.com/opentofu/opentofu/internal/encryption/registry"
"github.com/opentofu/opentofu/internal/varhcl"
"github.com/opentofu/opentofu/internal/gohcl"
"github.com/zclconf/go-cty/cty"
)
@ -107,7 +106,7 @@ func (e *targetBuilder) setupKeyProvider(cfg config.KeyProviderConfig, stack []c
keyProviderConfig := keyProviderDescriptor.ConfigStruct()
// Locate all the dependencies
deps, diags := varhcl.VariablesInBody(cfg.Body, keyProviderConfig)
deps, diags := gohcl.VariablesInBody(cfg.Body, keyProviderConfig)
if diags.HasErrors() {
return diags
}

View File

@ -0,0 +1,38 @@
# AWS KMS Key Provider
> [!WARNING]
> This file is not an end-user documentation, it is intended for developers. Please follow the user documentation on the OpenTofu website unless you want to work on the encryption code.
This folder contains the code for the AWS KMS Key Provider. The user will be able to provide a reference to an AWS KMS key which can be used to encrypt and decrypt the data.
## Configuration
You can configure this key provider by specifying the following options:
```hcl2
terraform {
encryption {
key_provider "aws_kms" "myprovider" {
kms_key_id = "1234abcd-12ab-34cd-56ef-1234567890ab"
}
}
}
```
## Key Provider Options - kms_key_id
The kms_key_id can refer to one of the following:
- Key ID: 1234abcd-12ab-34cd-56ef-1234567890ab
- Key ARN: arn:aws:kms:us-east-2:111122223333:key/1234abcd-12ab-34cd-56ef-1234567890ab
- Alias name: alias/ExampleAlias
- Alias ARN: arn:aws:kms:us-east-2:111122223333:alias/ExampleAlias
For more information see https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/service/kms#GenerateDataKeyInput
## State Snapshotting and Key Usage
### Overview
OpenTofu generates a new encryption key for every time we store encrypted data, ensuring high security by minimizing key reuse.
This has some minor cost implications that should be communicated to the end users, There may be more keys generated than expected as OpenTofu uses a new key for each state snapshot.
It is important to generate a new key for each state snapshot to ensure that the state snapshot is encrypted with a unique key instead of reusing the same key for all state snapshots and thus reducing the security of the system.

View File

@ -0,0 +1,242 @@
package aws_kms
import (
"context"
"fmt"
"os"
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/feature/ec2/imds"
"github.com/aws/aws-sdk-go-v2/service/kms"
"github.com/aws/aws-sdk-go-v2/service/kms/types"
awsbase "github.com/hashicorp/aws-sdk-go-base/v2"
baselogging "github.com/hashicorp/aws-sdk-go-base/v2/logging"
"github.com/opentofu/opentofu/internal/encryption/keyprovider"
"github.com/opentofu/opentofu/internal/httpclient"
"github.com/opentofu/opentofu/internal/logging"
"github.com/opentofu/opentofu/version"
)
type Config struct {
// KeyProvider Config
KMSKeyID string `hcl:"kms_key_id"`
KeySpec string `hcl:"key_spec"`
// Mirrored S3 Backend Config, mirror any changes
AccessKey string `hcl:"access_key,optional"`
Endpoints []ConfigEndpoints `hcl:"endpoints,block"`
MaxRetries int `hcl:"max_retries,optional"`
Profile string `hcl:"profile,optional"`
Region string `hcl:"region,optional"`
SecretKey string `hcl:"secret_key,optional"`
SkipCredsValidation bool `hcl:"skip_credentials_validation,optional"`
SkipRequestingAccountId bool `hcl:"skip_requesting_account_id,optional"`
STSRegion string `hcl:"sts_region,optional"`
Token string `hcl:"token,optional"`
HTTPProxy *string `hcl:"http_proxy,optional"`
HTTPSProxy *string `hcl:"https_proxy,optional"`
NoProxy string `hcl:"no_proxy,optional"`
Insecure bool `hcl:"insecure,optional"`
UseDualStackEndpoint bool `hcl:"use_dualstack_endpoint,optional"`
UseFIPSEndpoint bool `hcl:"use_fips_endpoint,optional"`
CustomCABundle string `hcl:"custom_ca_bundle,optional"`
EC2MetadataServiceEndpoint string `hcl:"ec2_metadata_service_endpoint,optional"`
EC2MetadataServiceEndpointMode string `hcl:"ec2_metadata_service_endpoint_mode,optional"`
SkipMetadataAPICheck *bool `hcl:"skip_metadata_api_check,optional"`
SharedCredentialsFiles []string `hcl:"shared_credentials_files,optional"`
SharedConfigFiles []string `hcl:"shared_config_files,optional"`
AssumeRole *AssumeRole `hcl:"assume_role,optional"`
AssumeRoleWithWebIdentity *AssumeRoleWithWebIdentity `hcl:"assume_role_with_web_identity,optional"`
AllowedAccountIds []string `hcl:"allowed_account_ids,optional"`
ForbiddenAccountIds []string `hcl:"forbidden_account_ids,optional"`
RetryMode string `hcl:"retry_mode,optional"`
}
func stringAttrEnvFallback(val string, env string) string {
if val != "" {
return val
}
return os.Getenv(env)
}
func stringArrayAttrEnvFallback(val []string, env string) []string {
if len(val) != 0 {
return val
}
envVal := os.Getenv(env)
if envVal != "" {
return []string{envVal}
}
return nil
}
func (c Config) asAWSBase() (*awsbase.Config, error) {
// Get endpoints to use
endpoints, err := c.getEndpoints()
if err != nil {
return nil, err
}
// Get assume role
assumeRole, err := c.AssumeRole.asAWSBase()
if err != nil {
return nil, err
}
// Get assume role with web identity
assumeRoleWithWebIdentity, err := c.AssumeRoleWithWebIdentity.asAWSBase()
if err != nil {
return nil, err
}
// validate region
if c.Region == "" && os.Getenv("AWS_REGION") == "" && os.Getenv("AWS_DEFAULT_REGION") == "" {
return nil, fmt.Errorf(`the "region" attribute or the "AWS_REGION" or "AWS_DEFAULT_REGION" environment variables must be set.`)
}
// Retry Mode
if c.MaxRetries == 0 {
c.MaxRetries = 5
}
var retryMode aws.RetryMode
if len(c.RetryMode) != 0 {
retryMode, err = aws.ParseRetryMode(c.RetryMode)
if err != nil {
return nil, fmt.Errorf("%w: expected %q or %q", err, aws.RetryModeStandard, aws.RetryModeAdaptive)
}
}
// IDMS handling
imdsEnabled := imds.ClientDefaultEnableState
if c.SkipMetadataAPICheck != nil {
if *c.SkipMetadataAPICheck {
imdsEnabled = imds.ClientEnabled
} else {
imdsEnabled = imds.ClientDisabled
}
}
// validate account_ids
if len(c.AllowedAccountIds) != 0 && len(c.ForbiddenAccountIds) != 0 {
return nil, fmt.Errorf("conflicting config attributes: only allowed_account_ids or forbidden_account_ids can be specified, not both")
}
return &awsbase.Config{
AccessKey: c.AccessKey,
CallerDocumentationURL: "https://opentofu.org/docs/language/settings/backends/s3", // TODO
CallerName: "KMS Key Provider",
IamEndpoint: stringAttrEnvFallback(endpoints.IAM, "AWS_ENDPOINT_URL_IAM"),
MaxRetries: c.MaxRetries,
RetryMode: retryMode,
Profile: c.Profile,
Region: c.Region,
SecretKey: c.SecretKey,
SkipCredsValidation: c.SkipCredsValidation,
SkipRequestingAccountId: c.SkipRequestingAccountId,
StsEndpoint: stringAttrEnvFallback(endpoints.STS, "AWS_ENDPOINT_URL_STS"),
StsRegion: c.STSRegion,
Token: c.Token,
// Note: we don't need to read env variables explicitly because they are read implicitly by aws-sdk-base-go:
// see: https://github.com/hashicorp/aws-sdk-go-base/blob/v2.0.0-beta.41/internal/config/config.go#L133
// which relies on: https://cs.opensource.google/go/x/net/+/refs/tags/v0.18.0:http/httpproxy/proxy.go;l=89-96
HTTPProxy: c.HTTPProxy,
HTTPSProxy: c.HTTPSProxy,
NoProxy: c.NoProxy,
Insecure: c.Insecure,
UseDualStackEndpoint: c.UseDualStackEndpoint,
UseFIPSEndpoint: c.UseFIPSEndpoint,
UserAgent: awsbase.UserAgentProducts{
{Name: "APN", Version: "1.0"},
{Name: httpclient.DefaultApplicationName, Version: version.String()},
},
CustomCABundle: stringAttrEnvFallback(c.CustomCABundle, "AWS_CA_BUNDLE"),
EC2MetadataServiceEnableState: imdsEnabled,
EC2MetadataServiceEndpoint: stringAttrEnvFallback(c.EC2MetadataServiceEndpoint, "AWS_EC2_METADATA_SERVICE_ENDPOINT"),
EC2MetadataServiceEndpointMode: stringAttrEnvFallback(c.EC2MetadataServiceEndpointMode, "AWS_EC2_METADATA_SERVICE_ENDPOINT_MODE"),
SharedCredentialsFiles: stringArrayAttrEnvFallback(c.SharedCredentialsFiles, "AWS_SHARED_CREDENTIALS_FILE"),
SharedConfigFiles: stringArrayAttrEnvFallback(c.SharedConfigFiles, "AWS_SHARED_CONFIG_FILE"),
AssumeRole: assumeRole,
AssumeRoleWithWebIdentity: assumeRoleWithWebIdentity,
AllowedAccountIds: c.AllowedAccountIds,
ForbiddenAccountIds: c.ForbiddenAccountIds,
}, nil
}
func (c Config) Build() (keyprovider.KeyProvider, keyprovider.KeyMeta, error) {
err := c.validate()
if err != nil {
return nil, nil, err
}
cfg, err := c.asAWSBase()
if err != nil {
return nil, nil, err
}
ctx := context.Background()
ctx, baselog := attachLoggerToContext(ctx)
cfg.Logger = baselog
_, awsConfig, awsDiags := awsbase.GetAwsConfig(ctx, cfg)
if awsDiags.HasError() {
out := "errors were encountered in aws kms configuration"
for _, diag := range awsDiags.Errors() {
out += "\n" + diag.Summary() + " : " + diag.Detail()
}
return nil, nil, fmt.Errorf(out)
}
return &keyProvider{
Config: c,
svc: kms.NewFromConfig(awsConfig),
ctx: ctx,
}, new(keyMeta), nil
}
// validate checks the configuration for the key provider
func (c Config) validate() (err error) {
if c.KMSKeyID == "" {
return &keyprovider.ErrInvalidConfiguration{
Message: "no kms_key_id provided",
}
}
if c.KeySpec == "" {
return &keyprovider.ErrInvalidConfiguration{
Message: "no key_spec provided",
}
}
spec := c.getKeySpecAsAWSType()
if spec == nil {
return &keyprovider.ErrInvalidConfiguration{
Message: fmt.Sprintf("invalid key_spec %s, expected one of %v", c.KeySpec, spec.Values()),
}
}
return nil
}
// getSpecAsAWSType handles conversion between the string from the config and the aws expected enum type
// it will return nil if it cannot find a match
func (c Config) getKeySpecAsAWSType() *types.DataKeySpec {
var spec types.DataKeySpec
for _, opt := range spec.Values() {
if string(opt) == c.KeySpec {
spec = opt
}
}
return &spec
}
// Mirrored from s3 backend config
func attachLoggerToContext(ctx context.Context) (context.Context, baselogging.HcLogger) {
ctx, baseLog := baselogging.NewHcLogger(ctx, logging.HCLogger().Named("backend-s3"))
ctx = baselogging.RegisterLogger(ctx, baseLog)
return ctx, baseLog
}

View File

@ -0,0 +1,118 @@
package aws_kms
import (
"fmt"
"strings"
"time"
"github.com/aws/aws-sdk-go-v2/aws/arn"
awsbase "github.com/hashicorp/aws-sdk-go-base/v2"
)
type AssumeRole struct {
RoleARN string `hcl:"role_arn"`
Duration string `hcl:"duration,optional"`
ExternalID string `hcl:"external_id,optional"`
Policy string `hcl:"policy,optional"`
PolicyARNs []string `hcl:"policy_arns,optional"`
SessionName string `hcl:"session_name,optional"`
Tags map[string]string `hcl:"tags,optional"`
TransitiveTagKeys []string `hcl:"transitive_tag_keys,optional"`
}
type AssumeRoleWithWebIdentity struct {
RoleARN string `hcl:"role_arn,optional"`
Duration string `hcl:"duration,optional"`
Policy string `hcl:"policy,optional"`
PolicyARNs []string `hcl:"policy_arns,optional"`
SessionName string `hcl:"session_name,optional"`
WebIdentityToken string `hcl:"web_identity_token,optional"`
WebIdentityTokenFile string `hcl:"web_identity_token_file,optional"`
}
func parseAssumeRoleDuration(val string) (dur time.Duration, err error) {
if len(val) == 0 {
return dur, nil
}
dur, err = time.ParseDuration(val)
if err != nil {
return dur, fmt.Errorf("invalid assume_role duration %q: %w", val, err)
}
minDur := 15 * time.Minute
maxDur := 12 * time.Hour
if (minDur > 0 && dur < minDur) || (maxDur > 0 && dur > maxDur) {
return dur, fmt.Errorf("assume_role duration must be between %s and %s, had %s", minDur, maxDur, dur)
}
return dur, nil
}
func validatePolicyARNs(arns []string) error {
for _, v := range arns {
arn, err := arn.Parse(v)
if err != nil {
return err
}
if !strings.HasPrefix(arn.Resource, "policy/") {
return fmt.Errorf("arn must be a valid IAM Policy ARN, got %q", v)
}
}
return nil
}
func (r *AssumeRole) asAWSBase() (*awsbase.AssumeRole, error) {
if r == nil {
return nil, nil
}
duration, err := parseAssumeRoleDuration(r.Duration)
if err != nil {
return nil, err
}
err = validatePolicyARNs(r.PolicyARNs)
if err != nil {
return nil, err
}
assumeRole := &awsbase.AssumeRole{
RoleARN: r.RoleARN,
Duration: duration,
ExternalID: r.ExternalID,
Policy: r.Policy,
PolicyARNs: r.PolicyARNs,
SessionName: r.SessionName,
Tags: r.Tags,
TransitiveTagKeys: r.TransitiveTagKeys,
}
return assumeRole, nil
}
func (r *AssumeRoleWithWebIdentity) asAWSBase() (*awsbase.AssumeRoleWithWebIdentity, error) {
if r == nil {
return nil, nil
}
if r.WebIdentityToken != "" && r.WebIdentityTokenFile != "" {
return nil, fmt.Errorf("conflicting config attributes: only web_identity_token or web_identity_token_file can be specified, not both")
}
duration, err := parseAssumeRoleDuration(r.Duration)
if err != nil {
return nil, err
}
err = validatePolicyARNs(r.PolicyARNs)
if err != nil {
return nil, err
}
return &awsbase.AssumeRoleWithWebIdentity{
RoleARN: stringAttrEnvFallback(r.RoleARN, "AWS_ROLE_ARN"),
Duration: duration,
Policy: r.Policy,
PolicyARNs: r.PolicyARNs,
SessionName: stringAttrEnvFallback(r.SessionName, "AWS_ROLE_SESSION_NAME"),
WebIdentityToken: stringAttrEnvFallback(r.WebIdentityToken, "AWS_WEB_IDENTITY_TOKEN"),
WebIdentityTokenFile: stringAttrEnvFallback(r.WebIdentityTokenFile, "AWS_WEB_IDENTITY_TOKEN_FILE"),
}, nil
}

View File

@ -0,0 +1,42 @@
package aws_kms
import (
"fmt"
"log"
"regexp"
)
type ConfigEndpoints struct {
IAM string `hcl:"iam,optional"`
STS string `hcl:"sts,optional"`
}
// Mirrored from s3 backend config
func includeProtoIfNessesary(endpoint string) string {
if matched, _ := regexp.MatchString("[a-z]*://.*", endpoint); !matched {
log.Printf("[DEBUG] Adding https:// prefix to endpoint '%s'", endpoint)
endpoint = fmt.Sprintf("https://%s", endpoint)
}
return endpoint
}
func (c Config) getEndpoints() (ConfigEndpoints, error) {
endpoints := ConfigEndpoints{}
// Make sure we have 0 or 1 endpoint blocks
if len(c.Endpoints) == 1 {
endpoints = c.Endpoints[0]
}
if len(c.Endpoints) > 1 {
return endpoints, fmt.Errorf("expected single aws_kms endpoints block, multiple provided")
}
// Endpoint formatting
if len(endpoints.IAM) != 0 {
endpoints.IAM = includeProtoIfNessesary(endpoints.IAM)
}
if len(endpoints.STS) != 0 {
endpoints.STS = includeProtoIfNessesary(endpoints.STS)
}
return endpoints, nil
}

View File

@ -0,0 +1,245 @@
package aws_kms
import (
"fmt"
"reflect"
"testing"
"time"
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/feature/ec2/imds"
"github.com/davecgh/go-spew/spew"
awsbase "github.com/hashicorp/aws-sdk-go-base/v2"
"github.com/hashicorp/hcl/v2"
"github.com/hashicorp/hcl/v2/hclsyntax"
"github.com/opentofu/opentofu/internal/gohcl"
"github.com/opentofu/opentofu/internal/httpclient"
"github.com/opentofu/opentofu/version"
)
func TestConfig_asAWSBase(t *testing.T) {
testCases := []struct {
name string
input string
expected awsbase.Config
}{
{
name: "minconfig",
input: `
kms_key_id = "my-kms-key-id"
key_spec = "AES_256"
region = "magic-mountain"`,
expected: awsbase.Config{
Region: "magic-mountain",
CallerDocumentationURL: "https://opentofu.org/docs/language/settings/backends/s3",
CallerName: "KMS Key Provider",
MaxRetries: 5,
UserAgent: awsbase.UserAgentProducts{
{Name: "APN", Version: "1.0"},
{Name: httpclient.DefaultApplicationName, Version: version.String()},
},
},
},
{
name: "maxconfig",
input: `
kms_key_id = "my-kms-key-id"
key_spec = "AES_256"
access_key = "my-access-key"
endpoints {
iam = "endpoint-iam"
sts = "endpoint-sts"
}
max_retries = 42
profile = "my-profile"
region = "my-region"
secret_key = "my-secret-key"
skip_credentials_validation = true
skip_requesting_account_id = true
sts_region = "my-sts-region"
token = "my-token"
http_proxy = "my-http-proxy"
https_proxy = "my-https-proxy"
no_proxy = "my-noproxy"
insecure = true
use_dualstack_endpoint = true
use_fips_endpoint = true
custom_ca_bundle = "my-custom-ca-bundle"
ec2_metadata_service_endpoint = "my-emde"
ec2_metadata_service_endpoint_mode = "my-emde-mode"
skip_metadata_api_check = false
shared_credentials_files = ["my-scredf"]
shared_config_files = ["my-sconff"]
assume_role = {
role_arn = "ar_arn"
duration = "4h"
external_id = "ar_extid"
policy = "ar_policy"
policy_arns = ["arn:aws:iam::123456789012:policy/AR"]
session_name = "ar_session_name"
tags = {
foo = "bar"
}
transitive_tag_keys = ["ar_tags"]
}
assume_role_with_web_identity = {
role_arn = "wi_arn"
duration = "5h"
policy = "wi_policy"
policy_arns = ["arn:aws:iam::123456789012:policy/WI"]
session_name = "wi_session_name"
web_identity_token = "wi_token"
//web_identity_token_file = "wi_token_file"
}
allowed_account_ids = ["account"]
//forbidden_account_ids = ?
retry_mode = "adaptive"
`,
expected: awsbase.Config{
CallerDocumentationURL: "https://opentofu.org/docs/language/settings/backends/s3",
CallerName: "KMS Key Provider",
UserAgent: awsbase.UserAgentProducts{
{Name: "APN", Version: "1.0"},
{Name: httpclient.DefaultApplicationName, Version: version.String()},
},
AccessKey: "my-access-key",
IamEndpoint: "https://endpoint-iam",
MaxRetries: 42,
Profile: "my-profile",
Region: "my-region",
SecretKey: "my-secret-key",
SkipCredsValidation: true,
SkipRequestingAccountId: true,
StsEndpoint: "https://endpoint-sts",
StsRegion: "my-sts-region",
Token: "my-token",
HTTPProxy: aws.String("my-http-proxy"),
HTTPSProxy: aws.String("my-https-proxy"),
NoProxy: "my-noproxy",
Insecure: true,
UseDualStackEndpoint: true,
UseFIPSEndpoint: true,
CustomCABundle: "my-custom-ca-bundle",
EC2MetadataServiceEnableState: imds.ClientDisabled,
EC2MetadataServiceEndpoint: "my-emde",
EC2MetadataServiceEndpointMode: "my-emde-mode",
SharedCredentialsFiles: []string{"my-scredf"},
SharedConfigFiles: []string{"my-sconff"},
AssumeRole: &awsbase.AssumeRole{
RoleARN: "ar_arn",
Duration: time.Hour * 4,
ExternalID: "ar_extid",
Policy: "ar_policy",
PolicyARNs: []string{
"arn:aws:iam::123456789012:policy/AR",
},
SessionName: "ar_session_name",
Tags: map[string]string{
"foo": "bar",
},
TransitiveTagKeys: []string{
"ar_tags",
},
},
AssumeRoleWithWebIdentity: &awsbase.AssumeRoleWithWebIdentity{
RoleARN: "wi_arn",
Duration: time.Hour * 5,
Policy: "wi_policy",
PolicyARNs: []string{
"arn:aws:iam::123456789012:policy/WI",
},
SessionName: "wi_session_name",
WebIdentityToken: "wi_token",
WebIdentityTokenFile: "",
},
AllowedAccountIds: []string{"account"},
RetryMode: aws.RetryModeAdaptive,
},
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
input, diags := hclsyntax.ParseConfig([]byte(tc.input), "test", hcl.InitialPos)
if diags.HasErrors() {
t.Fatal(diags.Error())
}
config := new(Config)
diags = gohcl.DecodeBody(input.Body, nil, config)
if diags.HasErrors() {
t.Fatal(diags.Error())
}
if config.KMSKeyID != "my-kms-key-id" {
t.Fatal("missing kms_key_id")
}
if config.KeySpec != "AES_256" {
t.Fatal("missing key_spec")
}
actual, err := config.asAWSBase()
if err != nil {
t.Fatal(err.Error())
}
if !reflect.DeepEqual(tc.expected, *actual) {
t.Fatalf("Expected %s, got %s", spew.Sdump(tc.expected), spew.Sdump(*actual))
}
})
}
}
func TestValidate(t *testing.T) {
testCases := []struct {
name string
input Config
expected error
}{
{
name: "valid",
input: Config{
KMSKeyID: "my-kms-key-id",
KeySpec: "AES_256",
},
expected: nil,
},
{
name: "missing kms_key_id",
input: Config{
KMSKeyID: "",
KeySpec: "AES_256",
},
expected: fmt.Errorf("no kms_key_id provided"),
},
{
name: "missing key_spec",
input: Config{
KMSKeyID: "my-kms-key-id",
KeySpec: "",
},
expected: fmt.Errorf("no key_spec provided"),
},
{
name: "invalid key_spec",
input: Config{
KMSKeyID: "my-kms-key-id",
KeySpec: "invalid??",
},
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
err := tc.input.validate()
// check if the error message is the same
if tc.expected != nil {
if err.Error() != tc.expected.Error() {
t.Fatalf("Expected %q, got %q", tc.expected.Error(), err.Error())
}
}
})
}
}

View File

@ -0,0 +1,20 @@
package aws_kms
import (
"github.com/opentofu/opentofu/internal/encryption/keyprovider"
)
func New() keyprovider.Descriptor {
return &descriptor{}
}
type descriptor struct {
}
func (f descriptor) ID() keyprovider.ID {
return "aws_kms"
}
func (f descriptor) ConfigStruct() keyprovider.Config {
return &Config{}
}

View File

@ -0,0 +1,73 @@
package aws_kms
import (
"context"
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/service/kms"
"github.com/aws/aws-sdk-go-v2/service/kms/types"
"github.com/opentofu/opentofu/internal/encryption/keyprovider"
)
type keyMeta struct {
CiphertextBlob []byte `json:"ciphertext_blob"`
}
func (m keyMeta) isPresent() bool {
return len(m.CiphertextBlob) != 0
}
type keyProvider struct {
Config
svc *kms.Client
ctx context.Context
}
func (p keyProvider) Provide(rawMeta keyprovider.KeyMeta) (keyprovider.Output, keyprovider.KeyMeta, error) {
if rawMeta == nil {
return keyprovider.Output{}, nil, keyprovider.ErrInvalidMetadata{Message: "bug: no metadata struct provided"}
}
inMeta := rawMeta.(*keyMeta)
outMeta := &keyMeta{}
out := keyprovider.Output{}
// as validation has happened in the config, we can safely cast here and not worry about the cast failing
spec := types.DataKeySpec(p.KeySpec)
generatedKeyData, err := p.svc.GenerateDataKey(p.ctx, &kms.GenerateDataKeyInput{
KeyId: aws.String(p.KMSKeyID),
KeySpec: spec,
})
if err != nil {
return out, outMeta, &keyprovider.ErrKeyProviderFailure{
Message: "failed to generate key",
Cause: err,
}
}
// Set initial outputs that are always set
out.EncryptionKey = generatedKeyData.Plaintext
outMeta.CiphertextBlob = generatedKeyData.CiphertextBlob
// We do not set the DecryptionKey here as we should only be setting the decryption key if we are decrypting
// and that is handled below when we check if the inMeta has a CiphertextBlob
if inMeta.isPresent() {
// We have an existing decryption key to decrypt, so we should now populate the DecryptionKey
decryptedKeyData, decryptErr := p.svc.Decrypt(p.ctx, &kms.DecryptInput{
KeyId: aws.String(p.KMSKeyID),
CiphertextBlob: inMeta.CiphertextBlob,
})
if decryptErr != nil {
return out, outMeta, decryptErr
}
// Set decryption key on the output
out.DecryptionKey = decryptedKeyData.Plaintext
}
return out, outMeta, nil
}

View File

@ -0,0 +1,180 @@
package aws_kms
import (
"context"
"fmt"
"os"
"testing"
"time"
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/service/kms"
"github.com/aws/aws-sdk-go-v2/service/kms/types"
awsbase "github.com/hashicorp/aws-sdk-go-base/v2"
)
// skipCheck checks if the test should be skipped or not based on environment variables
func skipCheck(t *testing.T) {
// check if TF_ACC and TF_KMS_TEST are unset
// if so, skip the test
if os.Getenv("TF_ACC") == "" && os.Getenv("TF_KMS_TEST") == "" {
t.Log("Skipping test because TF_ACC or TF_KMS_TEST is not set")
t.Skip()
}
}
const testKeyPrefix = "tf-acc-test-kms-key"
const testAliasPrefix = "alias/my-key-alias"
func TestKMSProvider_Simple(t *testing.T) {
skipCheck(t)
ctx := context.TODO()
keyName := fmt.Sprintf("%s-%x", testKeyPrefix, time.Now().Unix())
alias := fmt.Sprintf("%s-%x", testAliasPrefix, time.Now().Unix())
// Constructs a aws kms key provider config that accepts the alias as the key id
providerConfig := Config{
KMSKeyID: alias,
KeySpec: "AES_256",
}
// Mimic the creation of the aws client here via providerConfig.asAWSBase() so that
// we create a key in the same way that it will be read
awsBaseConfig, err := providerConfig.asAWSBase()
if err != nil {
t.Fatalf("Error creating AWS config: %s", err)
}
_, awsConfig, awsDiags := awsbase.GetAwsConfig(ctx, awsBaseConfig)
if awsDiags.HasError() {
t.Fatalf("Error creating AWS config: %v", awsDiags)
}
kmsClient := kms.NewFromConfig(awsConfig)
// Create the key
keyId := createKMSKey(ctx, t, kmsClient, keyName, awsBaseConfig.Region)
defer scheduleKMSKeyDeletion(ctx, t, kms.NewFromConfig(awsConfig), keyId)
// Create an alias for the key
createAlias(ctx, t, kmsClient, keyId, &alias)
defer deleteAlias(ctx, t, kms.NewFromConfig(awsConfig), &alias)
// Now that we have the config, we can build the provider
provider, metaIn, err := providerConfig.Build()
if err != nil {
t.Fatalf("Error building provider: %s", err)
}
// Now we can test the provider
output, meta, err := provider.Provide(metaIn)
if err != nil {
t.Fatalf("Error providing keys: %s", err)
}
if len(output.EncryptionKey) == 0 {
t.Fatalf("No encryption key provided")
}
if len(output.DecryptionKey) != 0 {
t.Fatalf("Decryption key provided and should not be")
}
if len(meta.(*keyMeta).CiphertextBlob) == 0 {
t.Fatalf("No ciphertext blob provided")
}
t.Log("Continue to meta -> decryption key")
// Now that we have a encyption key and it's meta, let's get the decryption key
output, meta, err = provider.Provide(meta)
if err != nil {
t.Fatalf("Error providing keys: %s", err)
}
if len(output.EncryptionKey) == 0 {
t.Fatalf("No encryption key provided")
}
if len(output.DecryptionKey) == 0 {
t.Fatalf("No decryption key provided")
}
if len(meta.(*keyMeta).CiphertextBlob) == 0 {
t.Fatalf("No ciphertext blob provided")
}
}
// createKMSKey creates a KMS key with the given name and region
func createKMSKey(ctx context.Context, t *testing.T, kmsClient *kms.Client, keyName string, region string) (keyID string) {
createKeyReq := kms.CreateKeyInput{
Tags: []types.Tag{
{
TagKey: aws.String("Name"),
TagValue: aws.String(keyName),
},
},
}
t.Logf("Creating KMS key %s in %s", keyName, region)
created, err := kmsClient.CreateKey(ctx, &createKeyReq)
if err != nil {
t.Fatalf("Error creating KMS key: %s", err)
}
return *created.KeyMetadata.KeyId
}
// createAlias creates a KMS alias for the given key
func createAlias(ctx context.Context, t *testing.T, kmsClient *kms.Client, keyID string, alias *string) {
if alias == nil {
return
}
t.Logf("Creating KMS alias %s for key %s", *alias, keyID)
aliasReq := kms.CreateAliasInput{
AliasName: aws.String(*alias),
TargetKeyId: aws.String(keyID),
}
_, err := kmsClient.CreateAlias(ctx, &aliasReq)
if err != nil {
t.Fatalf("Error creating KMS alias: %s", err)
}
}
// scheduleKMSKeyDeletion schedules the deletion of a KMS key
// this attempts to delete it in the fastest possible way (7 days)
func scheduleKMSKeyDeletion(ctx context.Context, t *testing.T, kmsClient *kms.Client, keyID string) {
deleteKeyReq := kms.ScheduleKeyDeletionInput{
KeyId: aws.String(keyID),
PendingWindowInDays: aws.Int32(7),
}
t.Logf("Scheduling KMS key %s for deletion", keyID)
_, err := kmsClient.ScheduleKeyDeletion(ctx, &deleteKeyReq)
if err != nil {
t.Fatalf("Error deleting KMS key: %s", err)
}
}
// deleteAlias deletes a KMS alias
func deleteAlias(ctx context.Context, t *testing.T, kmsClient *kms.Client, alias *string) {
if alias == nil {
return
}
t.Logf("Deleting KMS alias %s", *alias)
deleteAliasReq := kms.DeleteAliasInput{
AliasName: aws.String(*alias),
}
_, err := kmsClient.DeleteAlias(ctx, &deleteAliasReq)
if err != nil {
t.Fatalf("Error deleting KMS alias: %s", err)
}
}

View File

@ -27,6 +27,7 @@ type ErrInvalidConfiguration struct {
func (e ErrInvalidConfiguration) Error() string {
if e.Cause != nil {
if e.Message != "" {
return fmt.Sprintf("%s: %v", e.Message, e.Cause)
}

1
internal/gohcl/README.md Normal file
View File

@ -0,0 +1 @@
## This is a temporary fork of the github.com/hashicorp/hcl/gohcl/v2 package. It is in the process of being upstreamed

441
internal/gohcl/decode.go Normal file
View File

@ -0,0 +1,441 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: MPL-2.0
package gohcl
import (
"fmt"
"reflect"
"github.com/zclconf/go-cty/cty"
"github.com/hashicorp/hcl/v2"
"github.com/zclconf/go-cty/cty/convert"
"github.com/zclconf/go-cty/cty/gocty"
)
// DecodeBody extracts the configuration within the given body into the given
// value. This value must be a non-nil pointer to either a struct or
// a map, where in the former case the configuration will be decoded using
// struct tags and in the latter case only attributes are allowed and their
// values are decoded into the map.
//
// The given EvalContext is used to resolve any variables or functions in
// expressions encountered while decoding. This may be nil to require only
// constant values, for simple applications that do not support variables or
// functions.
//
// The returned diagnostics should be inspected with its HasErrors method to
// determine if the populated value is valid and complete. If error diagnostics
// are returned then the given value may have been partially-populated but
// may still be accessed by a careful caller for static analysis and editor
// integration use-cases.
func DecodeBody(body hcl.Body, ctx *hcl.EvalContext, val interface{}) hcl.Diagnostics {
rv := reflect.ValueOf(val)
if rv.Kind() != reflect.Ptr {
panic(fmt.Sprintf("target value must be a pointer, not %s", rv.Type().String()))
}
return decodeBodyToValue(body, ctx, rv.Elem())
}
func decodeBodyToValue(body hcl.Body, ctx *hcl.EvalContext, val reflect.Value) hcl.Diagnostics {
et := val.Type()
switch et.Kind() {
case reflect.Struct:
return decodeBodyToStruct(body, ctx, val)
case reflect.Map:
return decodeBodyToMap(body, ctx, val)
default:
panic(fmt.Sprintf("target value must be pointer to struct or map, not %s", et.String()))
}
}
func decodeBodyToStruct(body hcl.Body, ctx *hcl.EvalContext, val reflect.Value) hcl.Diagnostics {
schema, partial := ImpliedBodySchema(val.Interface())
var content *hcl.BodyContent
var leftovers hcl.Body
var diags hcl.Diagnostics
if partial {
content, leftovers, diags = body.PartialContent(schema)
} else {
content, diags = body.Content(schema)
}
if content == nil {
return diags
}
tags := getFieldTags(val.Type())
if tags.Body != nil {
fieldIdx := *tags.Body
field := val.Type().Field(fieldIdx)
fieldV := val.Field(fieldIdx)
switch {
case bodyType.AssignableTo(field.Type):
fieldV.Set(reflect.ValueOf(body))
default:
diags = append(diags, decodeBodyToValue(body, ctx, fieldV)...)
}
}
if tags.Remain != nil {
fieldIdx := *tags.Remain
field := val.Type().Field(fieldIdx)
fieldV := val.Field(fieldIdx)
switch {
case bodyType.AssignableTo(field.Type):
fieldV.Set(reflect.ValueOf(leftovers))
case attrsType.AssignableTo(field.Type):
attrs, attrsDiags := leftovers.JustAttributes()
if len(attrsDiags) > 0 {
diags = append(diags, attrsDiags...)
}
fieldV.Set(reflect.ValueOf(attrs))
default:
diags = append(diags, decodeBodyToValue(leftovers, ctx, fieldV)...)
}
}
for name, fieldIdx := range tags.Attributes {
attr := content.Attributes[name]
field := val.Type().Field(fieldIdx)
fieldV := val.Field(fieldIdx)
if attr == nil {
if !exprType.AssignableTo(field.Type) {
continue
}
// As a special case, if the target is of type hcl.Expression then
// we'll assign an actual expression that evalues to a cty null,
// so the caller can deal with it within the cty realm rather
// than within the Go realm.
synthExpr := hcl.StaticExpr(cty.NullVal(cty.DynamicPseudoType), body.MissingItemRange())
fieldV.Set(reflect.ValueOf(synthExpr))
continue
}
switch {
case attrType.AssignableTo(field.Type):
fieldV.Set(reflect.ValueOf(attr))
case exprType.AssignableTo(field.Type):
fieldV.Set(reflect.ValueOf(attr.Expr))
case field.Type.Kind() == reflect.Ptr && field.Type.Elem().Kind() == reflect.Struct:
// TODO might want to check for nil here
rn := reflect.New(field.Type.Elem())
fieldV.Set(rn)
diags = append(diags, DecodeExpression(
attr.Expr, ctx, fieldV.Interface(),
)...)
default:
diags = append(diags, DecodeExpression(
attr.Expr, ctx, fieldV.Addr().Interface(),
)...)
}
}
blocksByType := content.Blocks.ByType()
for typeName, fieldIdx := range tags.Blocks {
blocks := blocksByType[typeName]
field := val.Type().Field(fieldIdx)
ty := field.Type
isSlice := false
isPtr := false
if ty.Kind() == reflect.Slice {
isSlice = true
ty = ty.Elem()
}
if ty.Kind() == reflect.Ptr {
isPtr = true
ty = ty.Elem()
}
if len(blocks) > 1 && !isSlice {
diags = append(diags, &hcl.Diagnostic{
Severity: hcl.DiagError,
Summary: fmt.Sprintf("Duplicate %s block", typeName),
Detail: fmt.Sprintf(
"Only one %s block is allowed. Another was defined at %s.",
typeName, blocks[0].DefRange.String(),
),
Subject: &blocks[1].DefRange,
})
continue
}
if len(blocks) == 0 {
if isSlice || isPtr {
if val.Field(fieldIdx).IsNil() {
val.Field(fieldIdx).Set(reflect.Zero(field.Type))
}
} else {
diags = append(diags, &hcl.Diagnostic{
Severity: hcl.DiagError,
Summary: fmt.Sprintf("Missing %s block", typeName),
Detail: fmt.Sprintf("A %s block is required.", typeName),
Subject: body.MissingItemRange().Ptr(),
})
}
continue
}
switch {
case isSlice:
elemType := ty
if isPtr {
elemType = reflect.PtrTo(ty)
}
sli := val.Field(fieldIdx)
if sli.IsNil() {
sli = reflect.MakeSlice(reflect.SliceOf(elemType), len(blocks), len(blocks))
}
for i, block := range blocks {
if isPtr {
if i >= sli.Len() {
sli = reflect.Append(sli, reflect.New(ty))
}
v := sli.Index(i)
if v.IsNil() {
v = reflect.New(ty)
}
diags = append(diags, decodeBlockToValue(block, ctx, v.Elem())...)
sli.Index(i).Set(v)
} else {
if i >= sli.Len() {
sli = reflect.Append(sli, reflect.Indirect(reflect.New(ty)))
}
diags = append(diags, decodeBlockToValue(block, ctx, sli.Index(i))...)
}
}
if sli.Len() > len(blocks) {
sli.SetLen(len(blocks))
}
val.Field(fieldIdx).Set(sli)
default:
block := blocks[0]
if isPtr {
v := val.Field(fieldIdx)
if v.IsNil() {
v = reflect.New(ty)
}
diags = append(diags, decodeBlockToValue(block, ctx, v.Elem())...)
val.Field(fieldIdx).Set(v)
} else {
diags = append(diags, decodeBlockToValue(block, ctx, val.Field(fieldIdx))...)
}
}
}
return diags
}
func decodeBodyToMap(body hcl.Body, ctx *hcl.EvalContext, v reflect.Value) hcl.Diagnostics {
attrs, diags := body.JustAttributes()
if attrs == nil {
return diags
}
mv := reflect.MakeMap(v.Type())
for k, attr := range attrs {
switch {
case attrType.AssignableTo(v.Type().Elem()):
mv.SetMapIndex(reflect.ValueOf(k), reflect.ValueOf(attr))
case exprType.AssignableTo(v.Type().Elem()):
mv.SetMapIndex(reflect.ValueOf(k), reflect.ValueOf(attr.Expr))
default:
ev := reflect.New(v.Type().Elem())
diags = append(diags, DecodeExpression(attr.Expr, ctx, ev.Interface())...)
mv.SetMapIndex(reflect.ValueOf(k), ev.Elem())
}
}
v.Set(mv)
return diags
}
func decodeBlockToValue(block *hcl.Block, ctx *hcl.EvalContext, v reflect.Value) hcl.Diagnostics {
diags := decodeBodyToValue(block.Body, ctx, v)
if len(block.Labels) > 0 {
blockTags := getFieldTags(v.Type())
for li, lv := range block.Labels {
lfieldIdx := blockTags.Labels[li].FieldIndex
v.Field(lfieldIdx).Set(reflect.ValueOf(lv))
}
}
return diags
}
// DecodeExpression extracts the value of the given expression into the given
// value. This value must be something that gocty is able to decode into,
// since the final decoding is delegated to that package. If a reference to
// a struct is provided which contains gohcl tags, it will be decoded using
// the attr and optional tags.
//
// The given EvalContext is used to resolve any variables or functions in
// expressions encountered while decoding. This may be nil to require only
// constant values, for simple applications that do not support variables or
// functions.
//
// The returned diagnostics should be inspected with its HasErrors method to
// determine if the populated value is valid and complete. If error diagnostics
// are returned then the given value may have been partially-populated but
// may still be accessed by a careful caller for static analysis and editor
// integration use-cases.
func DecodeExpression(expr hcl.Expression, ctx *hcl.EvalContext, val interface{}) hcl.Diagnostics {
srcVal, diags := expr.Value(ctx)
if diags.HasErrors() {
return diags
}
return append(diags, DecodeValue(srcVal, expr.StartRange(), expr.Range(), val)...)
}
// DecodeValue extracts the given value into the provided target.
// This value must be something that gocty is able to decode into,
// since the final decoding is delegated to that package. If a reference to
// a struct is provided which contains gohcl tags, it will be decoded using
// the attr and optional tags.
//
// The returned diagnostics should be inspected with its HasErrors method to
// determine if the populated value is valid and complete. If error diagnostics
// are returned then the given value may have been partially-populated but
// may still be accessed by a careful caller for static analysis and editor
// integration use-cases.
func DecodeValue(srcVal cty.Value, subject hcl.Range, context hcl.Range, val interface{}) hcl.Diagnostics {
rv := reflect.ValueOf(val)
if rv.Type().Kind() == reflect.Ptr && rv.Type().Elem().Kind() == reflect.Struct && hasFieldTags(rv.Elem().Type()) {
attrs := make(hcl.Attributes)
for k, v := range srcVal.AsValueMap() {
attrs[k] = &hcl.Attribute{
Name: k,
Expr: hcl.StaticExpr(v, context),
Range: subject,
}
}
return decodeBodyToStruct(synthBody{
attrs: attrs,
subject: subject,
context: context,
}, nil, rv.Elem())
}
convTy, err := gocty.ImpliedType(val)
if err != nil {
panic(fmt.Sprintf("unsuitable DecodeExpression target: %s", err))
}
var diags hcl.Diagnostics
srcVal, err = convert.Convert(srcVal, convTy)
if err != nil {
diags = append(diags, &hcl.Diagnostic{
Severity: hcl.DiagError,
Summary: "Unsuitable value type",
Detail: fmt.Sprintf("Unsuitable value: %s", err.Error()),
Subject: subject.Ptr(),
Context: context.Ptr(),
})
return diags
}
err = gocty.FromCtyValue(srcVal, val)
if err != nil {
diags = append(diags, &hcl.Diagnostic{
Severity: hcl.DiagError,
Summary: "Unsuitable value type",
Detail: fmt.Sprintf("Unsuitable value: %s", err.Error()),
Subject: subject.Ptr(),
Context: context.Ptr(),
})
}
return diags
}
type synthBody struct {
attrs hcl.Attributes
subject hcl.Range
context hcl.Range
}
func (s synthBody) Content(schema *hcl.BodySchema) (*hcl.BodyContent, hcl.Diagnostics) {
body, partial, diags := s.PartialContent(schema)
attrs, _ := partial.JustAttributes()
for name := range attrs {
diags = append(diags, &hcl.Diagnostic{
Severity: hcl.DiagError,
Summary: "Unsupported argument",
Detail: fmt.Sprintf("An argument named %q is not expected here.", name),
Subject: s.subject.Ptr(),
Context: s.context.Ptr(),
})
}
return body, diags
}
func (s synthBody) PartialContent(schema *hcl.BodySchema) (*hcl.BodyContent, hcl.Body, hcl.Diagnostics) {
var diags hcl.Diagnostics
for _, block := range schema.Blocks {
panic("hcl block tags are not allowed in attribute structs: " + block.Type)
}
attrs := make(hcl.Attributes)
remainder := make(hcl.Attributes)
for _, attr := range schema.Attributes {
v, ok := s.attrs[attr.Name]
if !ok {
if attr.Required {
diags = append(diags, &hcl.Diagnostic{
Severity: hcl.DiagError,
Summary: "Missing required argument",
Detail: fmt.Sprintf("The argument %q is required, but no definition was found.", attr.Name),
Subject: s.subject.Ptr(),
Context: s.context.Ptr(),
})
}
continue
}
attrs[attr.Name] = v
}
for k, v := range s.attrs {
if _, ok := attrs[k]; !ok {
remainder[k] = v
}
}
return &hcl.BodyContent{
Attributes: attrs,
MissingItemRange: s.context,
}, synthBody{attrs: remainder}, diags
}
func (s synthBody) JustAttributes() (hcl.Attributes, hcl.Diagnostics) {
return s.attrs, nil
}
func (s synthBody) MissingItemRange() hcl.Range {
return s.context
}

View File

@ -0,0 +1,813 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: MPL-2.0
package gohcl
import (
"encoding/json"
"fmt"
"reflect"
"testing"
"github.com/davecgh/go-spew/spew"
"github.com/hashicorp/hcl/v2"
hclJSON "github.com/hashicorp/hcl/v2/json"
"github.com/zclconf/go-cty/cty"
)
func TestDecodeBody(t *testing.T) {
deepEquals := func(other interface{}) func(v interface{}) bool {
return func(v interface{}) bool {
return reflect.DeepEqual(v, other)
}
}
type withNameExpression struct {
Name hcl.Expression `hcl:"name"`
}
type withTwoAttributes struct {
A string `hcl:"a,optional"`
B string `hcl:"b,optional"`
}
type withNestedBlock struct {
Plain string `hcl:"plain,optional"`
Nested *withTwoAttributes `hcl:"nested,block"`
}
type withListofNestedBlocks struct {
Nested []*withTwoAttributes `hcl:"nested,block"`
}
type withListofNestedBlocksNoPointers struct {
Nested []withTwoAttributes `hcl:"nested,block"`
}
tests := []struct {
Body map[string]interface{}
Target func() interface{}
Check func(v interface{}) bool
DiagCount int
}{
{
map[string]interface{}{},
makeInstantiateType(struct{}{}),
deepEquals(struct{}{}),
0,
},
{
map[string]interface{}{},
makeInstantiateType(struct {
Name string `hcl:"name"`
}{}),
deepEquals(struct {
Name string `hcl:"name"`
}{}),
1, // name is required
},
{
map[string]interface{}{},
makeInstantiateType(struct {
Name *string `hcl:"name"`
}{}),
deepEquals(struct {
Name *string `hcl:"name"`
}{}),
0,
}, // name nil
{
map[string]interface{}{},
makeInstantiateType(struct {
Name string `hcl:"name,optional"`
}{}),
deepEquals(struct {
Name string `hcl:"name,optional"`
}{}),
0,
}, // name optional
{
map[string]interface{}{},
makeInstantiateType(withNameExpression{}),
func(v interface{}) bool {
if v == nil {
return false
}
wne, valid := v.(withNameExpression)
if !valid {
return false
}
if wne.Name == nil {
return false
}
nameVal, _ := wne.Name.Value(nil)
if !nameVal.IsNull() {
return false
}
return true
},
0,
},
{
map[string]interface{}{
"name": "Ermintrude",
},
makeInstantiateType(withNameExpression{}),
func(v interface{}) bool {
if v == nil {
return false
}
wne, valid := v.(withNameExpression)
if !valid {
return false
}
if wne.Name == nil {
return false
}
nameVal, _ := wne.Name.Value(nil)
if !nameVal.Equals(cty.StringVal("Ermintrude")).True() {
return false
}
return true
},
0,
},
{
map[string]interface{}{
"name": "Ermintrude",
},
makeInstantiateType(struct {
Name string `hcl:"name"`
}{}),
deepEquals(struct {
Name string `hcl:"name"`
}{"Ermintrude"}),
0,
},
{
map[string]interface{}{
"name": "Ermintrude",
"age": 23,
},
makeInstantiateType(struct {
Name string `hcl:"name"`
}{}),
deepEquals(struct {
Name string `hcl:"name"`
}{"Ermintrude"}),
1, // Extraneous "age" property
},
{
map[string]interface{}{
"name": "Ermintrude",
"age": 50,
},
makeInstantiateType(struct {
Name string `hcl:"name"`
Attrs hcl.Attributes `hcl:",remain"`
}{}),
func(gotI interface{}) bool {
got := gotI.(struct {
Name string `hcl:"name"`
Attrs hcl.Attributes `hcl:",remain"`
})
return got.Name == "Ermintrude" && len(got.Attrs) == 1 && got.Attrs["age"] != nil
},
0,
},
{
map[string]interface{}{
"name": "Ermintrude",
"age": 50,
},
makeInstantiateType(struct {
Name string `hcl:"name"`
Remain hcl.Body `hcl:",remain"`
}{}),
func(gotI interface{}) bool {
got := gotI.(struct {
Name string `hcl:"name"`
Remain hcl.Body `hcl:",remain"`
})
attrs, _ := got.Remain.JustAttributes()
return got.Name == "Ermintrude" && len(attrs) == 1 && attrs["age"] != nil
},
0,
},
{
map[string]interface{}{
"name": "Ermintrude",
"living": true,
},
makeInstantiateType(struct {
Name string `hcl:"name"`
Remain map[string]cty.Value `hcl:",remain"`
}{}),
deepEquals(struct {
Name string `hcl:"name"`
Remain map[string]cty.Value `hcl:",remain"`
}{
Name: "Ermintrude",
Remain: map[string]cty.Value{
"living": cty.True,
},
}),
0,
},
{
map[string]interface{}{
"name": "Ermintrude",
"age": 50,
},
makeInstantiateType(struct {
Name string `hcl:"name"`
Body hcl.Body `hcl:",body"`
Remain hcl.Body `hcl:",remain"`
}{}),
func(gotI interface{}) bool {
got := gotI.(struct {
Name string `hcl:"name"`
Body hcl.Body `hcl:",body"`
Remain hcl.Body `hcl:",remain"`
})
attrs, _ := got.Body.JustAttributes()
return got.Name == "Ermintrude" && len(attrs) == 2 &&
attrs["name"] != nil && attrs["age"] != nil
},
0,
},
{
map[string]interface{}{
"noodle": map[string]interface{}{},
},
makeInstantiateType(struct {
Noodle struct{} `hcl:"noodle,block"`
}{}),
func(gotI interface{}) bool {
// Generating no diagnostics is good enough for this one.
return true
},
0,
},
{
map[string]interface{}{
"noodle": []map[string]interface{}{{}},
},
makeInstantiateType(struct {
Noodle struct{} `hcl:"noodle,block"`
}{}),
func(gotI interface{}) bool {
// Generating no diagnostics is good enough for this one.
return true
},
0,
},
{
map[string]interface{}{
"noodle": []map[string]interface{}{{}, {}},
},
makeInstantiateType(struct {
Noodle struct{} `hcl:"noodle,block"`
}{}),
func(gotI interface{}) bool {
// Generating one diagnostic is good enough for this one.
return true
},
1,
},
{
map[string]interface{}{},
makeInstantiateType(struct {
Noodle struct{} `hcl:"noodle,block"`
}{}),
func(gotI interface{}) bool {
// Generating one diagnostic is good enough for this one.
return true
},
1,
},
{
map[string]interface{}{
"noodle": []map[string]interface{}{},
},
makeInstantiateType(struct {
Noodle struct{} `hcl:"noodle,block"`
}{}),
func(gotI interface{}) bool {
// Generating one diagnostic is good enough for this one.
return true
},
1,
},
{
map[string]interface{}{
"noodle": map[string]interface{}{},
},
makeInstantiateType(struct {
Noodle *struct{} `hcl:"noodle,block"`
}{}),
func(gotI interface{}) bool {
return gotI.(struct {
Noodle *struct{} `hcl:"noodle,block"`
}).Noodle != nil
},
0,
},
{
map[string]interface{}{
"noodle": []map[string]interface{}{{}},
},
makeInstantiateType(struct {
Noodle *struct{} `hcl:"noodle,block"`
}{}),
func(gotI interface{}) bool {
return gotI.(struct {
Noodle *struct{} `hcl:"noodle,block"`
}).Noodle != nil
},
0,
},
{
map[string]interface{}{
"noodle": []map[string]interface{}{},
},
makeInstantiateType(struct {
Noodle *struct{} `hcl:"noodle,block"`
}{}),
func(gotI interface{}) bool {
return gotI.(struct {
Noodle *struct{} `hcl:"noodle,block"`
}).Noodle == nil
},
0,
},
{
map[string]interface{}{
"noodle": []map[string]interface{}{{}, {}},
},
makeInstantiateType(struct {
Noodle *struct{} `hcl:"noodle,block"`
}{}),
func(gotI interface{}) bool {
// Generating one diagnostic is good enough for this one.
return true
},
1,
},
{
map[string]interface{}{
"noodle": []map[string]interface{}{},
},
makeInstantiateType(struct {
Noodle []struct{} `hcl:"noodle,block"`
}{}),
func(gotI interface{}) bool {
noodle := gotI.(struct {
Noodle []struct{} `hcl:"noodle,block"`
}).Noodle
return len(noodle) == 0
},
0,
},
{
map[string]interface{}{
"noodle": []map[string]interface{}{{}},
},
makeInstantiateType(struct {
Noodle []struct{} `hcl:"noodle,block"`
}{}),
func(gotI interface{}) bool {
noodle := gotI.(struct {
Noodle []struct{} `hcl:"noodle,block"`
}).Noodle
return len(noodle) == 1
},
0,
},
{
map[string]interface{}{
"noodle": []map[string]interface{}{{}, {}},
},
makeInstantiateType(struct {
Noodle []struct{} `hcl:"noodle,block"`
}{}),
func(gotI interface{}) bool {
noodle := gotI.(struct {
Noodle []struct{} `hcl:"noodle,block"`
}).Noodle
return len(noodle) == 2
},
0,
},
{
map[string]interface{}{
"noodle": map[string]interface{}{},
},
makeInstantiateType(struct {
Noodle struct {
Name string `hcl:"name,label"`
} `hcl:"noodle,block"`
}{}),
func(gotI interface{}) bool {
// Generating two diagnostics is good enough for this one.
// (one for the missing noodle block and the other for
// the JSON serialization detecting the missing level of
// heirarchy for the label.)
return true
},
2,
},
{
map[string]interface{}{
"noodle": map[string]interface{}{
"foo_foo": map[string]interface{}{},
},
},
makeInstantiateType(struct {
Noodle struct {
Name string `hcl:"name,label"`
} `hcl:"noodle,block"`
}{}),
func(gotI interface{}) bool {
noodle := gotI.(struct {
Noodle struct {
Name string `hcl:"name,label"`
} `hcl:"noodle,block"`
}).Noodle
return noodle.Name == "foo_foo"
},
0,
},
{
map[string]interface{}{
"noodle": map[string]interface{}{
"foo_foo": map[string]interface{}{},
"bar_baz": map[string]interface{}{},
},
},
makeInstantiateType(struct {
Noodle struct {
Name string `hcl:"name,label"`
} `hcl:"noodle,block"`
}{}),
func(gotI interface{}) bool {
// One diagnostic is enough for this one.
return true
},
1,
},
{
map[string]interface{}{
"noodle": map[string]interface{}{
"foo_foo": map[string]interface{}{},
"bar_baz": map[string]interface{}{},
},
},
makeInstantiateType(struct {
Noodles []struct {
Name string `hcl:"name,label"`
} `hcl:"noodle,block"`
}{}),
func(gotI interface{}) bool {
noodles := gotI.(struct {
Noodles []struct {
Name string `hcl:"name,label"`
} `hcl:"noodle,block"`
}).Noodles
return len(noodles) == 2 && (noodles[0].Name == "foo_foo" || noodles[0].Name == "bar_baz") && (noodles[1].Name == "foo_foo" || noodles[1].Name == "bar_baz") && noodles[0].Name != noodles[1].Name
},
0,
},
{
map[string]interface{}{
"noodle": map[string]interface{}{
"foo_foo": map[string]interface{}{
"type": "rice",
},
},
},
makeInstantiateType(struct {
Noodle struct {
Name string `hcl:"name,label"`
Type string `hcl:"type"`
} `hcl:"noodle,block"`
}{}),
func(gotI interface{}) bool {
noodle := gotI.(struct {
Noodle struct {
Name string `hcl:"name,label"`
Type string `hcl:"type"`
} `hcl:"noodle,block"`
}).Noodle
return noodle.Name == "foo_foo" && noodle.Type == "rice"
},
0,
},
{
map[string]interface{}{
"name": "Ermintrude",
"age": 34,
},
makeInstantiateType(map[string]string(nil)),
deepEquals(map[string]string{
"name": "Ermintrude",
"age": "34",
}),
0,
},
{
map[string]interface{}{
"name": "Ermintrude",
"age": 89,
},
makeInstantiateType(map[string]*hcl.Attribute(nil)),
func(gotI interface{}) bool {
got := gotI.(map[string]*hcl.Attribute)
return len(got) == 2 && got["name"] != nil && got["age"] != nil
},
0,
},
{
map[string]interface{}{
"name": "Ermintrude",
"age": 13,
},
makeInstantiateType(map[string]hcl.Expression(nil)),
func(gotI interface{}) bool {
got := gotI.(map[string]hcl.Expression)
return len(got) == 2 && got["name"] != nil && got["age"] != nil
},
0,
},
{
map[string]interface{}{
"name": "Ermintrude",
"living": true,
},
makeInstantiateType(map[string]cty.Value(nil)),
deepEquals(map[string]cty.Value{
"name": cty.StringVal("Ermintrude"),
"living": cty.True,
}),
0,
},
{
// Retain "nested" block while decoding
map[string]interface{}{
"plain": "foo",
},
func() interface{} {
return &withNestedBlock{
Plain: "bar",
Nested: &withTwoAttributes{
A: "bar",
},
}
},
func(gotI interface{}) bool {
foo := gotI.(withNestedBlock)
return foo.Plain == "foo" && foo.Nested != nil && foo.Nested.A == "bar"
},
0,
},
{
// Retain values in "nested" block while decoding
map[string]interface{}{
"nested": map[string]interface{}{
"a": "foo",
},
},
func() interface{} {
return &withNestedBlock{
Nested: &withTwoAttributes{
B: "bar",
},
}
},
func(gotI interface{}) bool {
foo := gotI.(withNestedBlock)
return foo.Nested.A == "foo" && foo.Nested.B == "bar"
},
0,
},
{
// Retain values in "nested" block list while decoding
map[string]interface{}{
"nested": []map[string]interface{}{
{
"a": "foo",
},
},
},
func() interface{} {
return &withListofNestedBlocks{
Nested: []*withTwoAttributes{
&withTwoAttributes{
B: "bar",
},
},
}
},
func(gotI interface{}) bool {
n := gotI.(withListofNestedBlocks)
return n.Nested[0].A == "foo" && n.Nested[0].B == "bar"
},
0,
},
{
// Remove additional elements from the list while decoding nested blocks
map[string]interface{}{
"nested": []map[string]interface{}{
{
"a": "foo",
},
},
},
func() interface{} {
return &withListofNestedBlocks{
Nested: []*withTwoAttributes{
&withTwoAttributes{
B: "bar",
},
&withTwoAttributes{
B: "bar",
},
},
}
},
func(gotI interface{}) bool {
n := gotI.(withListofNestedBlocks)
return len(n.Nested) == 1
},
0,
},
{
// Make sure decoding value slices works the same as pointer slices.
map[string]interface{}{
"nested": []map[string]interface{}{
{
"b": "bar",
},
{
"b": "baz",
},
},
},
func() interface{} {
return &withListofNestedBlocksNoPointers{
Nested: []withTwoAttributes{
{
B: "foo",
},
},
}
},
func(gotI interface{}) bool {
n := gotI.(withListofNestedBlocksNoPointers)
return n.Nested[0].B == "bar" && len(n.Nested) == 2
},
0,
},
}
for i, test := range tests {
// For convenience here we're going to use the JSON parser
// to process the given body.
buf, err := json.Marshal(test.Body)
if err != nil {
t.Fatalf("error JSON-encoding body for test %d: %s", i, err)
}
t.Run(string(buf), func(t *testing.T) {
file, diags := hclJSON.Parse(buf, "test.json")
if len(diags) != 0 {
t.Fatalf("diagnostics while parsing: %s", diags.Error())
}
targetVal := reflect.ValueOf(test.Target())
diags = DecodeBody(file.Body, nil, targetVal.Interface())
if len(diags) != test.DiagCount {
t.Errorf("wrong number of diagnostics %d; want %d", len(diags), test.DiagCount)
for _, diag := range diags {
t.Logf(" - %s", diag.Error())
}
}
got := targetVal.Elem().Interface()
if !test.Check(got) {
t.Errorf("wrong result\ngot: %s", spew.Sdump(got))
}
})
}
}
func TestDecodeExpression(t *testing.T) {
tests := []struct {
Value cty.Value
Target interface{}
Want interface{}
DiagCount int
}{
{
cty.StringVal("hello"),
"",
"hello",
0,
},
{
cty.StringVal("hello"),
cty.NilVal,
cty.StringVal("hello"),
0,
},
{
cty.NumberIntVal(2),
"",
"2",
0,
},
{
cty.StringVal("true"),
false,
true,
0,
},
{
cty.NullVal(cty.String),
"",
"",
1, // null value is not allowed
},
{
cty.UnknownVal(cty.String),
"",
"",
1, // value must be known
},
{
cty.ListVal([]cty.Value{cty.True}),
false,
false,
1, // bool required
},
}
for i, test := range tests {
t.Run(fmt.Sprintf("%02d", i), func(t *testing.T) {
expr := &fixedExpression{test.Value}
targetVal := reflect.New(reflect.TypeOf(test.Target))
diags := DecodeExpression(expr, nil, targetVal.Interface())
if len(diags) != test.DiagCount {
t.Errorf("wrong number of diagnostics %d; want %d", len(diags), test.DiagCount)
for _, diag := range diags {
t.Logf(" - %s", diag.Error())
}
}
got := targetVal.Elem().Interface()
if !reflect.DeepEqual(got, test.Want) {
t.Errorf("wrong result\ngot: %#v\nwant: %#v", got, test.Want)
}
})
}
}
type fixedExpression struct {
val cty.Value
}
func (e *fixedExpression) Value(ctx *hcl.EvalContext) (cty.Value, hcl.Diagnostics) {
return e.val, nil
}
func (e *fixedExpression) Range() (r hcl.Range) {
return
}
func (e *fixedExpression) StartRange() (r hcl.Range) {
return
}
func (e *fixedExpression) Variables() []hcl.Traversal {
return nil
}
func makeInstantiateType(target interface{}) func() interface{} {
return func() interface{} {
return reflect.New(reflect.TypeOf(target)).Interface()
}
}

65
internal/gohcl/doc.go Normal file
View File

@ -0,0 +1,65 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: MPL-2.0
// Package gohcl allows decoding HCL configurations into Go data structures.
//
// It provides a convenient and concise way of describing the schema for
// configuration and then accessing the resulting data via native Go
// types.
//
// A struct field tag scheme is used, similar to other decoding and
// unmarshalling libraries. The tags are formatted as in the following example:
//
// ThingType string `hcl:"thing_type,attr"`
//
// Within each tag there are two comma-separated tokens. The first is the
// name of the corresponding construct in configuration, while the second
// is a keyword giving the kind of construct expected. The following
// kind keywords are supported:
//
// attr (the default) indicates that the value is to be populated from an attribute
// block indicates that the value is to populated from a block
// label indicates that the value is to populated from a block label
// optional is the same as attr, but the field is optional
// remain indicates that the value is to be populated from the remaining body after populating other fields
//
// "attr" fields may either be of type *hcl.Expression, in which case the raw
// expression is assigned, or of any type accepted by gocty, in which case
// gocty will be used to assign the value to a native Go type.
//
// "block" fields may be a struct that recursively uses the same tags, or a
// slice of such structs, in which case multiple blocks of the corresponding
// type are decoded into the slice.
//
// "body" can be placed on a single field of type hcl.Body to capture
// the full hcl.Body that was decoded for a block. This does not allow leftover
// values like "remain", so a decoding error will still be returned if leftover
// fields are given. If you want to capture the decoding body PLUS leftover
// fields, you must specify a "remain" field as well to prevent errors. The
// body field and the remain field will both contain the leftover fields.
//
// "label" fields are considered only in a struct used as the type of a field
// marked as "block", and are used sequentially to capture the labels of
// the blocks being decoded. In this case, the name token is used only as
// an identifier for the label in diagnostic messages.
//
// "optional" fields behave like "attr" fields, but they are optional
// and will not give parsing errors if they are missing.
//
// "remain" can be placed on a single field that may be either of type
// hcl.Body or hcl.Attributes, in which case any remaining body content is
// placed into this field for delayed processing. If no "remain" field is
// present then any attributes or blocks not matched by another valid tag
// will cause an error diagnostic.
//
// Only a subset of this tagging/typing vocabulary is supported for the
// "Encode" family of functions. See the EncodeIntoBody docs for full details
// on the constraints there.
//
// Broadly-speaking this package deals with two types of error. The first is
// errors in the configuration itself, which are returned as diagnostics
// written with the configuration author as the target audience. The second
// is bugs in the calling program, such as invalid struct tags, which are
// surfaced via panics since there can be no useful runtime handling of such
// errors and they should certainly not be returned to the user as diagnostics.
package gohcl

194
internal/gohcl/encode.go Normal file
View File

@ -0,0 +1,194 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: MPL-2.0
package gohcl
import (
"fmt"
"reflect"
"sort"
"github.com/hashicorp/hcl/v2/hclwrite"
"github.com/zclconf/go-cty/cty/gocty"
)
// EncodeIntoBody replaces the contents of the given hclwrite Body with
// attributes and blocks derived from the given value, which must be a
// struct value or a pointer to a struct value with the struct tags defined
// in this package.
//
// This function can work only with fully-decoded data. It will ignore any
// fields tagged as "remain", any fields that decode attributes into either
// hcl.Attribute or hcl.Expression values, and any fields that decode blocks
// into hcl.Attributes values. This function does not have enough information
// to complete the decoding of these types.
//
// Any fields tagged as "label" are ignored by this function. Use EncodeAsBlock
// to produce a whole hclwrite.Block including block labels.
//
// As long as a suitable value is given to encode and the destination body
// is non-nil, this function will always complete. It will panic in case of
// any errors in the calling program, such as passing an inappropriate type
// or a nil body.
//
// The layout of the resulting HCL source is derived from the ordering of
// the struct fields, with blank lines around nested blocks of different types.
// Fields representing attributes should usually precede those representing
// blocks so that the attributes can group togather in the result. For more
// control, use the hclwrite API directly.
func EncodeIntoBody(val interface{}, dst *hclwrite.Body) {
rv := reflect.ValueOf(val)
ty := rv.Type()
if ty.Kind() == reflect.Ptr {
rv = rv.Elem()
ty = rv.Type()
}
if ty.Kind() != reflect.Struct {
panic(fmt.Sprintf("value is %s, not struct", ty.Kind()))
}
tags := getFieldTags(ty)
populateBody(rv, ty, tags, dst)
}
// EncodeAsBlock creates a new hclwrite.Block populated with the data from
// the given value, which must be a struct or pointer to struct with the
// struct tags defined in this package.
//
// If the given struct type has fields tagged with "label" tags then they
// will be used in order to annotate the created block with labels.
//
// This function has the same constraints as EncodeIntoBody and will panic
// if they are violated.
func EncodeAsBlock(val interface{}, blockType string) *hclwrite.Block {
rv := reflect.ValueOf(val)
ty := rv.Type()
if ty.Kind() == reflect.Ptr {
rv = rv.Elem()
ty = rv.Type()
}
if ty.Kind() != reflect.Struct {
panic(fmt.Sprintf("value is %s, not struct", ty.Kind()))
}
tags := getFieldTags(ty)
labels := make([]string, len(tags.Labels))
for i, lf := range tags.Labels {
lv := rv.Field(lf.FieldIndex)
// We just stringify whatever we find. It should always be a string
// but if not then we'll still do something reasonable.
labels[i] = fmt.Sprintf("%s", lv.Interface())
}
block := hclwrite.NewBlock(blockType, labels)
populateBody(rv, ty, tags, block.Body())
return block
}
func populateBody(rv reflect.Value, ty reflect.Type, tags *fieldTags, dst *hclwrite.Body) {
nameIdxs := make(map[string]int, len(tags.Attributes)+len(tags.Blocks))
namesOrder := make([]string, 0, len(tags.Attributes)+len(tags.Blocks))
for n, i := range tags.Attributes {
nameIdxs[n] = i
namesOrder = append(namesOrder, n)
}
for n, i := range tags.Blocks {
nameIdxs[n] = i
namesOrder = append(namesOrder, n)
}
sort.SliceStable(namesOrder, func(i, j int) bool {
ni, nj := namesOrder[i], namesOrder[j]
return nameIdxs[ni] < nameIdxs[nj]
})
dst.Clear()
prevWasBlock := false
for _, name := range namesOrder {
fieldIdx := nameIdxs[name]
field := ty.Field(fieldIdx)
fieldTy := field.Type
fieldVal := rv.Field(fieldIdx)
if fieldTy.Kind() == reflect.Ptr {
fieldTy = fieldTy.Elem()
fieldVal = fieldVal.Elem()
}
if _, isAttr := tags.Attributes[name]; isAttr {
if exprType.AssignableTo(fieldTy) || attrType.AssignableTo(fieldTy) {
continue // ignore undecoded fields
}
if !fieldVal.IsValid() {
continue // ignore (field value is nil pointer)
}
if fieldTy.Kind() == reflect.Ptr && fieldVal.IsNil() {
continue // ignore
}
if prevWasBlock {
dst.AppendNewline()
prevWasBlock = false
}
valTy, err := gocty.ImpliedType(fieldVal.Interface())
if err != nil {
panic(fmt.Sprintf("cannot encode %T as HCL expression: %s", fieldVal.Interface(), err))
}
val, err := gocty.ToCtyValue(fieldVal.Interface(), valTy)
if err != nil {
// This should never happen, since we should always be able
// to decode into the implied type.
panic(fmt.Sprintf("failed to encode %T as %#v: %s", fieldVal.Interface(), valTy, err))
}
dst.SetAttributeValue(name, val)
} else { // must be a block, then
elemTy := fieldTy
isSeq := false
if elemTy.Kind() == reflect.Slice || elemTy.Kind() == reflect.Array {
isSeq = true
elemTy = elemTy.Elem()
}
if bodyType.AssignableTo(elemTy) || attrsType.AssignableTo(elemTy) {
continue // ignore undecoded fields
}
prevWasBlock = false
if isSeq {
l := fieldVal.Len()
for i := 0; i < l; i++ {
elemVal := fieldVal.Index(i)
if !elemVal.IsValid() {
continue // ignore (elem value is nil pointer)
}
if elemTy.Kind() == reflect.Ptr && elemVal.IsNil() {
continue // ignore
}
block := EncodeAsBlock(elemVal.Interface(), name)
if !prevWasBlock {
dst.AppendNewline()
prevWasBlock = true
}
dst.AppendBlock(block)
}
} else {
if !fieldVal.IsValid() {
continue // ignore (field value is nil pointer)
}
if elemTy.Kind() == reflect.Ptr && fieldVal.IsNil() {
continue // ignore
}
block := EncodeAsBlock(fieldVal.Interface(), name)
if !prevWasBlock {
dst.AppendNewline()
prevWasBlock = true
}
dst.AppendBlock(block)
}
}
}
}

View File

@ -0,0 +1,67 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: MPL-2.0
package gohcl_test
import (
"fmt"
"github.com/hashicorp/hcl/v2/gohcl"
"github.com/hashicorp/hcl/v2/hclwrite"
)
func ExampleEncodeIntoBody() {
type Service struct {
Name string `hcl:"name,label"`
Exe []string `hcl:"executable"`
}
type Constraints struct {
OS string `hcl:"os"`
Arch string `hcl:"arch"`
}
type App struct {
Name string `hcl:"name"`
Desc string `hcl:"description"`
Constraints *Constraints `hcl:"constraints,block"`
Services []Service `hcl:"service,block"`
}
app := App{
Name: "awesome-app",
Desc: "Such an awesome application",
Constraints: &Constraints{
OS: "linux",
Arch: "amd64",
},
Services: []Service{
{
Name: "web",
Exe: []string{"./web", "--listen=:8080"},
},
{
Name: "worker",
Exe: []string{"./worker"},
},
},
}
f := hclwrite.NewEmptyFile()
gohcl.EncodeIntoBody(&app, f.Body())
fmt.Printf("%s", f.Bytes())
// Output:
// name = "awesome-app"
// description = "Such an awesome application"
//
// constraints {
// os = "linux"
// arch = "amd64"
// }
//
// service "web" {
// executable = ["./web", "--listen=:8080"]
// }
// service "worker" {
// executable = ["./worker"]
// }
}

196
internal/gohcl/schema.go Normal file
View File

@ -0,0 +1,196 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: MPL-2.0
package gohcl
import (
"fmt"
"reflect"
"sort"
"strings"
"github.com/hashicorp/hcl/v2"
)
// ImpliedBodySchema produces a hcl.BodySchema derived from the type of the
// given value, which must be a struct value or a pointer to one. If an
// inappropriate value is passed, this function will panic.
//
// The second return argument indicates whether the given struct includes
// a "remain" field, and thus the returned schema is non-exhaustive.
//
// This uses the tags on the fields of the struct to discover how each
// field's value should be expressed within configuration. If an invalid
// mapping is attempted, this function will panic.
func ImpliedBodySchema(val interface{}) (schema *hcl.BodySchema, partial bool) {
ty := reflect.TypeOf(val)
if ty.Kind() == reflect.Ptr {
ty = ty.Elem()
}
if ty.Kind() != reflect.Struct {
panic(fmt.Sprintf("given value must be struct, not %T", val))
}
var attrSchemas []hcl.AttributeSchema
var blockSchemas []hcl.BlockHeaderSchema
tags := getFieldTags(ty)
attrNames := make([]string, 0, len(tags.Attributes))
for n := range tags.Attributes {
attrNames = append(attrNames, n)
}
sort.Strings(attrNames)
for _, n := range attrNames {
idx := tags.Attributes[n]
optional := tags.Optional[n]
field := ty.Field(idx)
var required bool
switch {
case field.Type.AssignableTo(exprType):
// If we're decoding to hcl.Expression then absense can be
// indicated via a null value, so we don't specify that
// the field is required during decoding.
required = false
case field.Type.Kind() != reflect.Ptr && !optional:
required = true
default:
required = false
}
attrSchemas = append(attrSchemas, hcl.AttributeSchema{
Name: n,
Required: required,
})
}
blockNames := make([]string, 0, len(tags.Blocks))
for n := range tags.Blocks {
blockNames = append(blockNames, n)
}
sort.Strings(blockNames)
for _, n := range blockNames {
idx := tags.Blocks[n]
field := ty.Field(idx)
fty := field.Type
if fty.Kind() == reflect.Slice {
fty = fty.Elem()
}
if fty.Kind() == reflect.Ptr {
fty = fty.Elem()
}
if fty.Kind() != reflect.Struct {
panic(fmt.Sprintf(
"hcl 'block' tag kind cannot be applied to %s field %s: struct required", field.Type.String(), field.Name,
))
}
ftags := getFieldTags(fty)
var labelNames []string
if len(ftags.Labels) > 0 {
labelNames = make([]string, len(ftags.Labels))
for i, l := range ftags.Labels {
labelNames[i] = l.Name
}
}
blockSchemas = append(blockSchemas, hcl.BlockHeaderSchema{
Type: n,
LabelNames: labelNames,
})
}
partial = tags.Remain != nil
schema = &hcl.BodySchema{
Attributes: attrSchemas,
Blocks: blockSchemas,
}
return schema, partial
}
func hasFieldTags(ty reflect.Type) bool {
ct := ty.NumField()
for i := 0; i < ct; i++ {
field := ty.Field(i)
tag := field.Tag.Get("hcl")
if tag != "" {
return true
}
}
return false
}
type fieldTags struct {
Attributes map[string]int
Blocks map[string]int
Labels []labelField
Remain *int
Body *int
Optional map[string]bool
}
type labelField struct {
FieldIndex int
Name string
}
func getFieldTags(ty reflect.Type) *fieldTags {
ret := &fieldTags{
Attributes: map[string]int{},
Blocks: map[string]int{},
Optional: map[string]bool{},
}
ct := ty.NumField()
for i := 0; i < ct; i++ {
field := ty.Field(i)
tag := field.Tag.Get("hcl")
if tag == "" {
continue
}
comma := strings.Index(tag, ",")
var name, kind string
if comma != -1 {
name = tag[:comma]
kind = tag[comma+1:]
} else {
name = tag
kind = "attr"
}
switch kind {
case "attr":
ret.Attributes[name] = i
case "block":
ret.Blocks[name] = i
case "label":
ret.Labels = append(ret.Labels, labelField{
FieldIndex: i,
Name: name,
})
case "remain":
if ret.Remain != nil {
panic("only one 'remain' tag is permitted")
}
idx := i // copy, because this loop will continue assigning to i
ret.Remain = &idx
case "body":
if ret.Body != nil {
panic("only one 'body' tag is permitted")
}
idx := i // copy, because this loop will continue assigning to i
ret.Body = &idx
case "optional":
ret.Attributes[name] = i
ret.Optional[name] = true
default:
panic(fmt.Sprintf("invalid hcl field tag kind %q on %s %q", kind, field.Type.String(), field.Name))
}
}
return ret
}

View File

@ -0,0 +1,233 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: MPL-2.0
package gohcl
import (
"fmt"
"reflect"
"testing"
"github.com/davecgh/go-spew/spew"
"github.com/hashicorp/hcl/v2"
)
func TestImpliedBodySchema(t *testing.T) {
tests := []struct {
val interface{}
wantSchema *hcl.BodySchema
wantPartial bool
}{
{
struct{}{},
&hcl.BodySchema{},
false,
},
{
struct {
Ignored bool
}{},
&hcl.BodySchema{},
false,
},
{
struct {
Attr1 bool `hcl:"attr1"`
Attr2 bool `hcl:"attr2"`
}{},
&hcl.BodySchema{
Attributes: []hcl.AttributeSchema{
{
Name: "attr1",
Required: true,
},
{
Name: "attr2",
Required: true,
},
},
},
false,
},
{
struct {
Attr *bool `hcl:"attr,attr"`
}{},
&hcl.BodySchema{
Attributes: []hcl.AttributeSchema{
{
Name: "attr",
Required: false,
},
},
},
false,
},
{
struct {
Thing struct{} `hcl:"thing,block"`
}{},
&hcl.BodySchema{
Blocks: []hcl.BlockHeaderSchema{
{
Type: "thing",
},
},
},
false,
},
{
struct {
Thing struct {
Type string `hcl:"type,label"`
Name string `hcl:"name,label"`
} `hcl:"thing,block"`
}{},
&hcl.BodySchema{
Blocks: []hcl.BlockHeaderSchema{
{
Type: "thing",
LabelNames: []string{"type", "name"},
},
},
},
false,
},
{
struct {
Thing []struct {
Type string `hcl:"type,label"`
Name string `hcl:"name,label"`
} `hcl:"thing,block"`
}{},
&hcl.BodySchema{
Blocks: []hcl.BlockHeaderSchema{
{
Type: "thing",
LabelNames: []string{"type", "name"},
},
},
},
false,
},
{
struct {
Thing *struct {
Type string `hcl:"type,label"`
Name string `hcl:"name,label"`
} `hcl:"thing,block"`
}{},
&hcl.BodySchema{
Blocks: []hcl.BlockHeaderSchema{
{
Type: "thing",
LabelNames: []string{"type", "name"},
},
},
},
false,
},
{
struct {
Thing struct {
Name string `hcl:"name,label"`
Something string `hcl:"something"`
} `hcl:"thing,block"`
}{},
&hcl.BodySchema{
Blocks: []hcl.BlockHeaderSchema{
{
Type: "thing",
LabelNames: []string{"name"},
},
},
},
false,
},
{
struct {
Doodad string `hcl:"doodad"`
Thing struct {
Name string `hcl:"name,label"`
} `hcl:"thing,block"`
}{},
&hcl.BodySchema{
Attributes: []hcl.AttributeSchema{
{
Name: "doodad",
Required: true,
},
},
Blocks: []hcl.BlockHeaderSchema{
{
Type: "thing",
LabelNames: []string{"name"},
},
},
},
false,
},
{
struct {
Doodad string `hcl:"doodad"`
Config string `hcl:",remain"`
}{},
&hcl.BodySchema{
Attributes: []hcl.AttributeSchema{
{
Name: "doodad",
Required: true,
},
},
},
true,
},
{
struct {
Expr hcl.Expression `hcl:"expr"`
}{},
&hcl.BodySchema{
Attributes: []hcl.AttributeSchema{
{
Name: "expr",
Required: false,
},
},
},
false,
},
{
struct {
Meh string `hcl:"meh,optional"`
}{},
&hcl.BodySchema{
Attributes: []hcl.AttributeSchema{
{
Name: "meh",
Required: false,
},
},
},
false,
},
}
for _, test := range tests {
t.Run(fmt.Sprintf("%#v", test.val), func(t *testing.T) {
schema, partial := ImpliedBodySchema(test.val)
if !reflect.DeepEqual(schema, test.wantSchema) {
t.Errorf(
"wrong schema\ngot: %s\nwant: %s",
spew.Sdump(schema), spew.Sdump(test.wantSchema),
)
}
if partial != test.wantPartial {
t.Errorf(
"wrong partial flag\ngot: %#v\nwant: %#v",
partial, test.wantPartial,
)
}
})
}
}

19
internal/gohcl/types.go Normal file
View File

@ -0,0 +1,19 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: MPL-2.0
package gohcl
import (
"reflect"
"github.com/hashicorp/hcl/v2"
)
var victimExpr hcl.Expression
var victimBody hcl.Body
var exprType = reflect.TypeOf(&victimExpr).Elem()
var bodyType = reflect.TypeOf(&victimBody).Elem()
var blockType = reflect.TypeOf((*hcl.Block)(nil))
var attrType = reflect.TypeOf((*hcl.Attribute)(nil))
var attrsType = reflect.TypeOf(hcl.Attributes(nil))

View File

@ -3,14 +3,13 @@
// This is a fork of hashicorp/hcl/gohcl/decode.go that pulls out variable dependencies in attributes
package varhcl
package gohcl
import (
"fmt"
"reflect"
"github.com/hashicorp/hcl/v2"
"github.com/hashicorp/hcl/v2/gohcl"
)
func VariablesInBody(body hcl.Body, val interface{}) ([]hcl.Traversal, hcl.Diagnostics) {
@ -37,7 +36,7 @@ func findVariablesInBody(body hcl.Body, val reflect.Value) ([]hcl.Traversal, hcl
func findVariablesInBodyStruct(body hcl.Body, val reflect.Value) ([]hcl.Traversal, hcl.Diagnostics) {
var variables []hcl.Traversal
schema, partial := gohcl.ImpliedBodySchema(val.Interface())
schema, partial := ImpliedBodySchema(val.Interface())
var content *hcl.BodyContent
var diags hcl.Diagnostics

View File

@ -3,15 +3,15 @@
// Copyright (c) 2023 HashiCorp, Inc.
// SPDX-License-Identifier: MPL-2.0
package varhcl
package gohcl_test
import (
"fmt"
"testing"
"github.com/hashicorp/hcl/v2"
"github.com/hashicorp/hcl/v2/gohcl"
"github.com/hashicorp/hcl/v2/hclsyntax"
"github.com/opentofu/opentofu/internal/gohcl"
"github.com/zclconf/go-cty/cty"
)
@ -46,7 +46,7 @@ func Test(t *testing.T) {
println()
println("> Detect Variables")
vars, diags := VariablesInBody(file.Body, ob)
vars, diags := gohcl.VariablesInBody(file.Body, ob)
println(diags.Error())
for _, v := range vars {
ident := ""

View File

@ -1,85 +0,0 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: MPL-2.0
// This is a partial copy of hashicorp/hcl/gohcl/schema.go to allow access to internal variables.
// vardecode.go should be upstreamed instead.
package varhcl
import (
"fmt"
"reflect"
"strings"
)
type fieldTags struct {
Attributes map[string]int
Blocks map[string]int
Labels []labelField
Remain *int
Body *int
Optional map[string]bool
}
type labelField struct {
FieldIndex int
Name string
}
func getFieldTags(ty reflect.Type) *fieldTags {
ret := &fieldTags{
Attributes: map[string]int{},
Blocks: map[string]int{},
Optional: map[string]bool{},
}
ct := ty.NumField()
for i := 0; i < ct; i++ {
field := ty.Field(i)
tag := field.Tag.Get("hcl")
if tag == "" {
continue
}
comma := strings.Index(tag, ",")
var name, kind string
if comma != -1 {
name = tag[:comma]
kind = tag[comma+1:]
} else {
name = tag
kind = "attr"
}
switch kind {
case "attr":
ret.Attributes[name] = i
case "block":
ret.Blocks[name] = i
case "label":
ret.Labels = append(ret.Labels, labelField{
FieldIndex: i,
Name: name,
})
case "remain":
if ret.Remain != nil {
panic("only one 'remain' tag is permitted")
}
idx := i // copy, because this loop will continue assigning to i
ret.Remain = &idx
case "body":
if ret.Body != nil {
panic("only one 'body' tag is permitted")
}
idx := i // copy, because this loop will continue assigning to i
ret.Body = &idx
case "optional":
ret.Attributes[name] = i
ret.Optional[name] = true
default:
panic(fmt.Sprintf("invalid hcl field tag kind %q on %s %q", kind, field.Type.String(), field.Name))
}
}
return ret
}

View File

@ -11,7 +11,7 @@ echo "==> Checking that code complies with static analysis requirements..."
skip="internal/legacy|backend/remote-state/"
# Skip generated code for protobufs.
skip=$skip"|internal/planproto|internal/tfplugin5|internal/tfplugin6"
skip=$skip"|internal/planproto|internal/tfplugin5|internal/tfplugin6|internal/gohcl"
packages=$(go list ./... | egrep -v ${skip})