opentofu/internal/backend/remote-state/s3/client_test.go
namgyalangmo cb2e9119aa
Update copyright notice (#1232)
Signed-off-by: namgyalangmo <75657887+namgyalangmo@users.noreply.github.com>
2024-02-08 09:48:59 +00:00

329 lines
9.0 KiB
Go

// Copyright (c) The OpenTofu Authors
// SPDX-License-Identifier: MPL-2.0
// Copyright (c) 2023 HashiCorp, Inc.
// SPDX-License-Identifier: MPL-2.0
package s3
import (
"bytes"
"context"
"crypto/md5"
"fmt"
"strings"
"testing"
"time"
"github.com/opentofu/opentofu/internal/backend"
"github.com/opentofu/opentofu/internal/states/remote"
"github.com/opentofu/opentofu/internal/states/statefile"
"github.com/opentofu/opentofu/internal/states/statemgr"
)
func TestRemoteClient_impl(t *testing.T) {
var _ remote.Client = new(RemoteClient)
var _ remote.ClientLocker = new(RemoteClient)
}
func TestRemoteClient(t *testing.T) {
testACC(t)
bucketName := fmt.Sprintf("%s-%x", testBucketPrefix, time.Now().Unix())
keyName := "testState"
b := backend.TestBackendConfig(t, New(), backend.TestWrapConfig(map[string]interface{}{
"bucket": bucketName,
"key": keyName,
"encrypt": true,
})).(*Backend)
ctx := context.TODO()
createS3Bucket(ctx, t, b.s3Client, bucketName, b.awsConfig.Region)
defer deleteS3Bucket(ctx, t, b.s3Client, bucketName)
state, err := b.StateMgr(backend.DefaultStateName)
if err != nil {
t.Fatal(err)
}
remote.TestClient(t, state.(*remote.State).Client)
}
func TestRemoteClientLocks(t *testing.T) {
testACC(t)
bucketName := fmt.Sprintf("%s-%x", testBucketPrefix, time.Now().Unix())
keyName := "testState"
b1 := backend.TestBackendConfig(t, New(), backend.TestWrapConfig(map[string]interface{}{
"bucket": bucketName,
"key": keyName,
"encrypt": true,
"dynamodb_table": bucketName,
})).(*Backend)
b2 := backend.TestBackendConfig(t, New(), backend.TestWrapConfig(map[string]interface{}{
"bucket": bucketName,
"key": keyName,
"encrypt": true,
"dynamodb_table": bucketName,
})).(*Backend)
ctx := context.TODO()
createS3Bucket(ctx, t, b1.s3Client, bucketName, b1.awsConfig.Region)
defer deleteS3Bucket(ctx, t, b1.s3Client, bucketName)
createDynamoDBTable(ctx, t, b1.dynClient, bucketName)
defer deleteDynamoDBTable(ctx, t, b1.dynClient, bucketName)
s1, err := b1.StateMgr(backend.DefaultStateName)
if err != nil {
t.Fatal(err)
}
s2, err := b2.StateMgr(backend.DefaultStateName)
if err != nil {
t.Fatal(err)
}
remote.TestRemoteLocks(t, s1.(*remote.State).Client, s2.(*remote.State).Client)
}
// verify that we can unlock a state with an existing lock
func TestForceUnlock(t *testing.T) {
testACC(t)
bucketName := fmt.Sprintf("%s-force-%x", testBucketPrefix, time.Now().Unix())
keyName := "testState"
b1 := backend.TestBackendConfig(t, New(), backend.TestWrapConfig(map[string]interface{}{
"bucket": bucketName,
"key": keyName,
"encrypt": true,
"dynamodb_table": bucketName,
})).(*Backend)
b2 := backend.TestBackendConfig(t, New(), backend.TestWrapConfig(map[string]interface{}{
"bucket": bucketName,
"key": keyName,
"encrypt": true,
"dynamodb_table": bucketName,
})).(*Backend)
ctx := context.TODO()
createS3Bucket(ctx, t, b1.s3Client, bucketName, b1.awsConfig.Region)
defer deleteS3Bucket(ctx, t, b1.s3Client, bucketName)
createDynamoDBTable(ctx, t, b1.dynClient, bucketName)
defer deleteDynamoDBTable(ctx, t, b1.dynClient, bucketName)
// first test with default
s1, err := b1.StateMgr(backend.DefaultStateName)
if err != nil {
t.Fatal(err)
}
info := statemgr.NewLockInfo()
info.Operation = "test"
info.Who = "clientA"
lockID, err := s1.Lock(info)
if err != nil {
t.Fatal("unable to get initial lock:", err)
}
// s1 is now locked, get the same state through s2 and unlock it
s2, err := b2.StateMgr(backend.DefaultStateName)
if err != nil {
t.Fatal("failed to get default state to force unlock:", err)
}
if err := s2.Unlock(lockID); err != nil {
t.Fatal("failed to force-unlock default state")
}
// now try the same thing with a named state
// first test with default
s1, err = b1.StateMgr("test")
if err != nil {
t.Fatal(err)
}
info = statemgr.NewLockInfo()
info.Operation = "test"
info.Who = "clientA"
lockID, err = s1.Lock(info)
if err != nil {
t.Fatal("unable to get initial lock:", err)
}
// s1 is now locked, get the same state through s2 and unlock it
s2, err = b2.StateMgr("test")
if err != nil {
t.Fatal("failed to get named state to force unlock:", err)
}
if err = s2.Unlock(lockID); err != nil {
t.Fatal("failed to force-unlock named state")
}
}
func TestRemoteClient_clientMD5(t *testing.T) {
testACC(t)
bucketName := fmt.Sprintf("%s-%x", testBucketPrefix, time.Now().Unix())
keyName := "testState"
b := backend.TestBackendConfig(t, New(), backend.TestWrapConfig(map[string]interface{}{
"bucket": bucketName,
"key": keyName,
"dynamodb_table": bucketName,
})).(*Backend)
ctx := context.TODO()
createS3Bucket(ctx, t, b.s3Client, bucketName, b.awsConfig.Region)
defer deleteS3Bucket(ctx, t, b.s3Client, bucketName)
createDynamoDBTable(ctx, t, b.dynClient, bucketName)
defer deleteDynamoDBTable(ctx, t, b.dynClient, bucketName)
s, err := b.StateMgr(backend.DefaultStateName)
if err != nil {
t.Fatal(err)
}
client := s.(*remote.State).Client.(*RemoteClient)
sum := md5.Sum([]byte("test"))
if err := client.putMD5(ctx, sum[:]); err != nil {
t.Fatal(err)
}
getSum, err := client.getMD5(ctx)
if err != nil {
t.Fatal(err)
}
if !bytes.Equal(getSum, sum[:]) {
t.Fatalf("getMD5 returned the wrong checksum: expected %x, got %x", sum[:], getSum)
}
if err := client.deleteMD5(ctx); err != nil {
t.Fatal(err)
}
if getSum, err := client.getMD5(ctx); err == nil {
t.Fatalf("expected getMD5 error, got none. checksum: %x", getSum)
}
}
// verify that a client won't return a state with an incorrect checksum.
func TestRemoteClient_stateChecksum(t *testing.T) {
testACC(t)
bucketName := fmt.Sprintf("%s-%x", testBucketPrefix, time.Now().Unix())
keyName := "testState"
b1 := backend.TestBackendConfig(t, New(), backend.TestWrapConfig(map[string]interface{}{
"bucket": bucketName,
"key": keyName,
"dynamodb_table": bucketName,
})).(*Backend)
ctx := context.TODO()
createS3Bucket(ctx, t, b1.s3Client, bucketName, b1.awsConfig.Region)
defer deleteS3Bucket(ctx, t, b1.s3Client, bucketName)
createDynamoDBTable(ctx, t, b1.dynClient, bucketName)
defer deleteDynamoDBTable(ctx, t, b1.dynClient, bucketName)
s1, err := b1.StateMgr(backend.DefaultStateName)
if err != nil {
t.Fatal(err)
}
client1 := s1.(*remote.State).Client
// create an old and new state version to persist
s := statemgr.TestFullInitialState()
sf := &statefile.File{State: s}
var oldState bytes.Buffer
if err := statefile.Write(sf, &oldState); err != nil {
t.Fatal(err)
}
sf.Serial++
var newState bytes.Buffer
if err := statefile.Write(sf, &newState); err != nil {
t.Fatal(err)
}
// Use b2 without a dynamodb_table to bypass the lock table to write the state directly.
// client2 will write the "incorrect" state, simulating s3 eventually consistency delays
b2 := backend.TestBackendConfig(t, New(), backend.TestWrapConfig(map[string]interface{}{
"bucket": bucketName,
"key": keyName,
})).(*Backend)
s2, err := b2.StateMgr(backend.DefaultStateName)
if err != nil {
t.Fatal(err)
}
client2 := s2.(*remote.State).Client
// write the new state through client2 so that there is no checksum yet
if err := client2.Put(newState.Bytes()); err != nil {
t.Fatal(err)
}
// verify that we can pull a state without a checksum
if _, err := client1.Get(); err != nil {
t.Fatal(err)
}
// write the new state back with its checksum
if err := client1.Put(newState.Bytes()); err != nil {
t.Fatal(err)
}
// put an empty state in place to check for panics during get
if err := client2.Put([]byte{}); err != nil {
t.Fatal(err)
}
// remove the timeouts so we can fail immediately
origTimeout := consistencyRetryTimeout
origInterval := consistencyRetryPollInterval
defer func() {
consistencyRetryTimeout = origTimeout
consistencyRetryPollInterval = origInterval
}()
consistencyRetryTimeout = 0
consistencyRetryPollInterval = 0
// fetching an empty state through client1 should now error out due to a
// mismatched checksum.
if _, err := client1.Get(); !strings.HasPrefix(err.Error(), errBadChecksumFmt[:80]) {
t.Fatalf("expected state checksum error: got %s", err)
}
// put the old state in place of the new, without updating the checksum
if err := client2.Put(oldState.Bytes()); err != nil {
t.Fatal(err)
}
// fetching the wrong state through client1 should now error out due to a
// mismatched checksum.
if _, err := client1.Get(); !strings.HasPrefix(err.Error(), errBadChecksumFmt[:80]) {
t.Fatalf("expected state checksum error: got %s", err)
}
// update the state with the correct one after we Get again
testChecksumHook = func() {
if err := client2.Put(newState.Bytes()); err != nil {
t.Fatal(err)
}
testChecksumHook = nil
}
consistencyRetryTimeout = origTimeout
// this final Get will fail to fail the checksum verification, the above
// callback will update the state with the correct version, and Get should
// retry automatically.
if _, err := client1.Get(); err != nil {
t.Fatal(err)
}
}