From 979bf5ce3f8f86ed07ceb5e87f0175ec9a7cec0c Mon Sep 17 00:00:00 2001 From: Christian Mesh Date: Thu, 28 Mar 2024 11:14:08 -0400 Subject: [PATCH] Fix #1407: Pass through metadata fields in state encryption (#1417) Signed-off-by: Christian Mesh --- internal/encryption/base.go | 4 +-- internal/encryption/example_test.go | 8 ++--- internal/encryption/plan.go | 2 +- internal/encryption/state.go | 52 +++++++++++++++++++++++++++-- 4 files changed, 57 insertions(+), 9 deletions(-) diff --git a/internal/encryption/base.go b/internal/encryption/base.go index 3579b4e3ae..922c217701 100644 --- a/internal/encryption/base.go +++ b/internal/encryption/base.go @@ -87,7 +87,7 @@ func IsEncryptionPayload(data []byte) (bool, error) { return es.Version != "", nil } -func (s *baseEncryption) encrypt(data []byte) ([]byte, error) { +func (s *baseEncryption) encrypt(data []byte, enhance func(basedata) interface{}) ([]byte, error) { // No configuration provided, don't do anything if s.target == nil { return data, nil @@ -117,7 +117,7 @@ func (s *baseEncryption) encrypt(data []byte) ([]byte, error) { Meta: s.encMeta, Data: encd, } - jsond, err := json.Marshal(es) + jsond, err := json.Marshal(enhance(es)) if err != nil { return nil, fmt.Errorf("unable to encode encrypted data as json: %w", err) } diff --git a/internal/encryption/example_test.go b/internal/encryption/example_test.go index fd37f75294..4fed215939 100644 --- a/internal/encryption/example_test.go +++ b/internal/encryption/example_test.go @@ -63,15 +63,15 @@ func Example() { sfe := enc.State() - // Encrypt the data, for this example we will be using the string "test", + // Encrypt the data, for this example we will be using the string `{"serial": 42, "lineage": "magic"}`, // but in a real world scenario this would be the plan file. - sourceData := []byte("test") + sourceData := []byte(`{"serial": 42, "lineage": "magic"}`) encrypted, err := sfe.EncryptState(sourceData) if err != nil { panic(err) } - if string(encrypted) == "test" { + if string(encrypted) == `{"serial": 42, "lineage": "magic"}` { panic("The data has not been encrypted!") } @@ -82,7 +82,7 @@ func Example() { } fmt.Printf("%s\n", decryptedState) - // Output: test + // Output: {"serial": 42, "lineage": "magic"} } func handleDiags(diags hcl.Diagnostics) { diff --git a/internal/encryption/plan.go b/internal/encryption/plan.go index 1088a3f5b5..a1bb5366cf 100644 --- a/internal/encryption/plan.go +++ b/internal/encryption/plan.go @@ -55,7 +55,7 @@ func newPlanEncryption(enc *encryption, target *config.TargetConfig, enforced bo } func (p planEncryption) EncryptPlan(data []byte) ([]byte, error) { - return p.base.encrypt(data) + return p.base.encrypt(data, func(base basedata) interface{} { return base }) } func (p planEncryption) DecryptPlan(data []byte) ([]byte, error) { diff --git a/internal/encryption/state.go b/internal/encryption/state.go index ea026f258c..91b1c69232 100644 --- a/internal/encryption/state.go +++ b/internal/encryption/state.go @@ -62,12 +62,32 @@ func newStateEncryption(enc *encryption, target *config.TargetConfig, enforced b return &stateEncryption{base}, diags } +type statedata struct { + Serial *int `json:"serial"` + Lineage string `json:"lineage"` +} + func (s *stateEncryption) EncryptState(plainState []byte) ([]byte, error) { - return s.base.encrypt(plainState) + var passthrough statedata + err := json.Unmarshal(plainState, &passthrough) + if err != nil { + return nil, err + } + + return s.base.encrypt(plainState, func(base basedata) interface{} { + // Merge together the base encryption data and the passthrough fields + return struct { + statedata + basedata + }{ + statedata: passthrough, + basedata: base, + } + }) } func (s *stateEncryption) DecryptState(encryptedState []byte) ([]byte, error) { - return s.base.decrypt(encryptedState, func(data []byte) error { + decryptedState, err := s.base.decrypt(encryptedState, func(data []byte) error { tmp := struct { FormatVersion string `json:"terraform_version"` }{} @@ -82,6 +102,34 @@ func (s *stateEncryption) DecryptState(encryptedState []byte) ([]byte, error) { // Probably a state file return nil }) + + if err != nil { + return nil, err + } + + // Make sure that the state passthrough fields match + var encrypted statedata + err = json.Unmarshal(encryptedState, &encrypted) + if err != nil { + return nil, err + } + var state statedata + err = json.Unmarshal(decryptedState, &state) + if err != nil { + return nil, err + } + + // TODO make encrypted.Serial non-optional. This is only for supporting alpha1 states! + if encrypted.Serial != nil && state.Serial != nil && *state.Serial != *encrypted.Serial { + return nil, fmt.Errorf("invalid state metadata, serial field mismatch %v vs %v", *encrypted.Serial, *state.Serial) + } + + // TODO make encrypted.Lineage non-optional. This is only for supporting alpha1 states! + if encrypted.Lineage != "" && state.Lineage != encrypted.Lineage { + return nil, fmt.Errorf("invalid state metadata, linage field mismatch %v vs %v", encrypted.Lineage, state.Lineage) + } + + return decryptedState, nil } func StateEncryptionDisabled() StateEncryption {