github.com/opentofu/opentofu@v1.7.1/internal/backend/remote-state/s3/client_test.go (about)

     1  // Copyright (c) The OpenTofu Authors
     2  // SPDX-License-Identifier: MPL-2.0
     3  // Copyright (c) 2023 HashiCorp, Inc.
     4  // SPDX-License-Identifier: MPL-2.0
     5  
     6  package s3
     7  
     8  import (
     9  	"bytes"
    10  	"context"
    11  	"crypto/md5"
    12  	"fmt"
    13  	"strings"
    14  	"testing"
    15  	"time"
    16  
    17  	"github.com/opentofu/opentofu/internal/backend"
    18  	"github.com/opentofu/opentofu/internal/encryption"
    19  	"github.com/opentofu/opentofu/internal/states/remote"
    20  	"github.com/opentofu/opentofu/internal/states/statefile"
    21  	"github.com/opentofu/opentofu/internal/states/statemgr"
    22  )
    23  
    24  func TestRemoteClient_impl(t *testing.T) {
    25  	var _ remote.Client = new(RemoteClient)
    26  	var _ remote.ClientLocker = new(RemoteClient)
    27  }
    28  
    29  func TestRemoteClient(t *testing.T) {
    30  	testACC(t)
    31  	bucketName := fmt.Sprintf("%s-%x", testBucketPrefix, time.Now().Unix())
    32  	keyName := "testState"
    33  
    34  	b := backend.TestBackendConfig(t, New(encryption.StateEncryptionDisabled()), backend.TestWrapConfig(map[string]interface{}{
    35  		"bucket":  bucketName,
    36  		"key":     keyName,
    37  		"encrypt": true,
    38  	})).(*Backend)
    39  
    40  	ctx := context.TODO()
    41  	createS3Bucket(ctx, t, b.s3Client, bucketName, b.awsConfig.Region)
    42  	defer deleteS3Bucket(ctx, t, b.s3Client, bucketName)
    43  
    44  	state, err := b.StateMgr(backend.DefaultStateName)
    45  	if err != nil {
    46  		t.Fatal(err)
    47  	}
    48  
    49  	remote.TestClient(t, state.(*remote.State).Client)
    50  }
    51  
    52  func TestRemoteClientLocks(t *testing.T) {
    53  	testACC(t)
    54  	bucketName := fmt.Sprintf("%s-%x", testBucketPrefix, time.Now().Unix())
    55  	keyName := "testState"
    56  
    57  	b1 := backend.TestBackendConfig(t, New(encryption.StateEncryptionDisabled()), backend.TestWrapConfig(map[string]interface{}{
    58  		"bucket":         bucketName,
    59  		"key":            keyName,
    60  		"encrypt":        true,
    61  		"dynamodb_table": bucketName,
    62  	})).(*Backend)
    63  
    64  	b2 := backend.TestBackendConfig(t, New(encryption.StateEncryptionDisabled()), backend.TestWrapConfig(map[string]interface{}{
    65  		"bucket":         bucketName,
    66  		"key":            keyName,
    67  		"encrypt":        true,
    68  		"dynamodb_table": bucketName,
    69  	})).(*Backend)
    70  
    71  	ctx := context.TODO()
    72  	createS3Bucket(ctx, t, b1.s3Client, bucketName, b1.awsConfig.Region)
    73  	defer deleteS3Bucket(ctx, t, b1.s3Client, bucketName)
    74  	createDynamoDBTable(ctx, t, b1.dynClient, bucketName)
    75  	defer deleteDynamoDBTable(ctx, t, b1.dynClient, bucketName)
    76  
    77  	s1, err := b1.StateMgr(backend.DefaultStateName)
    78  	if err != nil {
    79  		t.Fatal(err)
    80  	}
    81  
    82  	s2, err := b2.StateMgr(backend.DefaultStateName)
    83  	if err != nil {
    84  		t.Fatal(err)
    85  	}
    86  
    87  	remote.TestRemoteLocks(t, s1.(*remote.State).Client, s2.(*remote.State).Client)
    88  }
    89  
    90  // verify that we can unlock a state with an existing lock
    91  func TestForceUnlock(t *testing.T) {
    92  	testACC(t)
    93  	bucketName := fmt.Sprintf("%s-force-%x", testBucketPrefix, time.Now().Unix())
    94  	keyName := "testState"
    95  
    96  	b1 := backend.TestBackendConfig(t, New(encryption.StateEncryptionDisabled()), backend.TestWrapConfig(map[string]interface{}{
    97  		"bucket":         bucketName,
    98  		"key":            keyName,
    99  		"encrypt":        true,
   100  		"dynamodb_table": bucketName,
   101  	})).(*Backend)
   102  
   103  	b2 := backend.TestBackendConfig(t, New(encryption.StateEncryptionDisabled()), backend.TestWrapConfig(map[string]interface{}{
   104  		"bucket":         bucketName,
   105  		"key":            keyName,
   106  		"encrypt":        true,
   107  		"dynamodb_table": bucketName,
   108  	})).(*Backend)
   109  
   110  	ctx := context.TODO()
   111  	createS3Bucket(ctx, t, b1.s3Client, bucketName, b1.awsConfig.Region)
   112  	defer deleteS3Bucket(ctx, t, b1.s3Client, bucketName)
   113  	createDynamoDBTable(ctx, t, b1.dynClient, bucketName)
   114  	defer deleteDynamoDBTable(ctx, t, b1.dynClient, bucketName)
   115  
   116  	// first test with default
   117  	s1, err := b1.StateMgr(backend.DefaultStateName)
   118  	if err != nil {
   119  		t.Fatal(err)
   120  	}
   121  
   122  	info := statemgr.NewLockInfo()
   123  	info.Operation = "test"
   124  	info.Who = "clientA"
   125  
   126  	lockID, err := s1.Lock(info)
   127  	if err != nil {
   128  		t.Fatal("unable to get initial lock:", err)
   129  	}
   130  
   131  	// s1 is now locked, get the same state through s2 and unlock it
   132  	s2, err := b2.StateMgr(backend.DefaultStateName)
   133  	if err != nil {
   134  		t.Fatal("failed to get default state to force unlock:", err)
   135  	}
   136  
   137  	if err := s2.Unlock(lockID); err != nil {
   138  		t.Fatal("failed to force-unlock default state")
   139  	}
   140  
   141  	// now try the same thing with a named state
   142  	// first test with default
   143  	s1, err = b1.StateMgr("test")
   144  	if err != nil {
   145  		t.Fatal(err)
   146  	}
   147  
   148  	info = statemgr.NewLockInfo()
   149  	info.Operation = "test"
   150  	info.Who = "clientA"
   151  
   152  	lockID, err = s1.Lock(info)
   153  	if err != nil {
   154  		t.Fatal("unable to get initial lock:", err)
   155  	}
   156  
   157  	// s1 is now locked, get the same state through s2 and unlock it
   158  	s2, err = b2.StateMgr("test")
   159  	if err != nil {
   160  		t.Fatal("failed to get named state to force unlock:", err)
   161  	}
   162  
   163  	if err = s2.Unlock(lockID); err != nil {
   164  		t.Fatal("failed to force-unlock named state")
   165  	}
   166  }
   167  
   168  func TestRemoteClient_clientMD5(t *testing.T) {
   169  	testACC(t)
   170  
   171  	bucketName := fmt.Sprintf("%s-%x", testBucketPrefix, time.Now().Unix())
   172  	keyName := "testState"
   173  
   174  	b := backend.TestBackendConfig(t, New(encryption.StateEncryptionDisabled()), backend.TestWrapConfig(map[string]interface{}{
   175  		"bucket":         bucketName,
   176  		"key":            keyName,
   177  		"dynamodb_table": bucketName,
   178  	})).(*Backend)
   179  
   180  	ctx := context.TODO()
   181  	createS3Bucket(ctx, t, b.s3Client, bucketName, b.awsConfig.Region)
   182  	defer deleteS3Bucket(ctx, t, b.s3Client, bucketName)
   183  	createDynamoDBTable(ctx, t, b.dynClient, bucketName)
   184  	defer deleteDynamoDBTable(ctx, t, b.dynClient, bucketName)
   185  
   186  	s, err := b.StateMgr(backend.DefaultStateName)
   187  	if err != nil {
   188  		t.Fatal(err)
   189  	}
   190  	client := s.(*remote.State).Client.(*RemoteClient)
   191  
   192  	sum := md5.Sum([]byte("test"))
   193  
   194  	if err := client.putMD5(ctx, sum[:]); err != nil {
   195  		t.Fatal(err)
   196  	}
   197  
   198  	getSum, err := client.getMD5(ctx)
   199  	if err != nil {
   200  		t.Fatal(err)
   201  	}
   202  
   203  	if !bytes.Equal(getSum, sum[:]) {
   204  		t.Fatalf("getMD5 returned the wrong checksum: expected %x, got %x", sum[:], getSum)
   205  	}
   206  
   207  	if err := client.deleteMD5(ctx); err != nil {
   208  		t.Fatal(err)
   209  	}
   210  
   211  	if getSum, err := client.getMD5(ctx); err == nil {
   212  		t.Fatalf("expected getMD5 error, got none. checksum: %x", getSum)
   213  	}
   214  }
   215  
   216  // verify that a client won't return a state with an incorrect checksum.
   217  func TestRemoteClient_stateChecksum(t *testing.T) {
   218  	testACC(t)
   219  
   220  	bucketName := fmt.Sprintf("%s-%x", testBucketPrefix, time.Now().Unix())
   221  	keyName := "testState"
   222  
   223  	b1 := backend.TestBackendConfig(t, New(encryption.StateEncryptionDisabled()), backend.TestWrapConfig(map[string]interface{}{
   224  		"bucket":         bucketName,
   225  		"key":            keyName,
   226  		"dynamodb_table": bucketName,
   227  	})).(*Backend)
   228  
   229  	ctx := context.TODO()
   230  	createS3Bucket(ctx, t, b1.s3Client, bucketName, b1.awsConfig.Region)
   231  	defer deleteS3Bucket(ctx, t, b1.s3Client, bucketName)
   232  	createDynamoDBTable(ctx, t, b1.dynClient, bucketName)
   233  	defer deleteDynamoDBTable(ctx, t, b1.dynClient, bucketName)
   234  
   235  	s1, err := b1.StateMgr(backend.DefaultStateName)
   236  	if err != nil {
   237  		t.Fatal(err)
   238  	}
   239  	client1 := s1.(*remote.State).Client
   240  
   241  	// create an old and new state version to persist
   242  	s := statemgr.TestFullInitialState()
   243  	sf := &statefile.File{State: s}
   244  	var oldState bytes.Buffer
   245  	if err := statefile.Write(sf, &oldState, encryption.StateEncryptionDisabled()); err != nil {
   246  		t.Fatal(err)
   247  	}
   248  	sf.Serial++
   249  	var newState bytes.Buffer
   250  	if err := statefile.Write(sf, &newState, encryption.StateEncryptionDisabled()); err != nil {
   251  		t.Fatal(err)
   252  	}
   253  
   254  	// Use b2 without a dynamodb_table to bypass the lock table to write the state directly.
   255  	// client2 will write the "incorrect" state, simulating s3 eventually consistency delays
   256  	b2 := backend.TestBackendConfig(t, New(encryption.StateEncryptionDisabled()), backend.TestWrapConfig(map[string]interface{}{
   257  		"bucket": bucketName,
   258  		"key":    keyName,
   259  	})).(*Backend)
   260  	s2, err := b2.StateMgr(backend.DefaultStateName)
   261  	if err != nil {
   262  		t.Fatal(err)
   263  	}
   264  	client2 := s2.(*remote.State).Client
   265  
   266  	// write the new state through client2 so that there is no checksum yet
   267  	if err := client2.Put(newState.Bytes()); err != nil {
   268  		t.Fatal(err)
   269  	}
   270  
   271  	// verify that we can pull a state without a checksum
   272  	if _, err := client1.Get(); err != nil {
   273  		t.Fatal(err)
   274  	}
   275  
   276  	// write the new state back with its checksum
   277  	if err := client1.Put(newState.Bytes()); err != nil {
   278  		t.Fatal(err)
   279  	}
   280  
   281  	// put an empty state in place to check for panics during get
   282  	if err := client2.Put([]byte{}); err != nil {
   283  		t.Fatal(err)
   284  	}
   285  
   286  	// remove the timeouts so we can fail immediately
   287  	origTimeout := consistencyRetryTimeout
   288  	origInterval := consistencyRetryPollInterval
   289  	defer func() {
   290  		consistencyRetryTimeout = origTimeout
   291  		consistencyRetryPollInterval = origInterval
   292  	}()
   293  	consistencyRetryTimeout = 0
   294  	consistencyRetryPollInterval = 0
   295  
   296  	// fetching an empty state through client1 should now error out due to a
   297  	// mismatched checksum.
   298  	if _, err := client1.Get(); !strings.HasPrefix(err.Error(), errBadChecksumFmt[:80]) {
   299  		t.Fatalf("expected state checksum error: got %s", err)
   300  	}
   301  
   302  	// put the old state in place of the new, without updating the checksum
   303  	if err := client2.Put(oldState.Bytes()); err != nil {
   304  		t.Fatal(err)
   305  	}
   306  
   307  	// fetching the wrong state through client1 should now error out due to a
   308  	// mismatched checksum.
   309  	if _, err := client1.Get(); !strings.HasPrefix(err.Error(), errBadChecksumFmt[:80]) {
   310  		t.Fatalf("expected state checksum error: got %s", err)
   311  	}
   312  
   313  	// update the state with the correct one after we Get again
   314  	testChecksumHook = func() {
   315  		if err := client2.Put(newState.Bytes()); err != nil {
   316  			t.Fatal(err)
   317  		}
   318  		testChecksumHook = nil
   319  	}
   320  
   321  	consistencyRetryTimeout = origTimeout
   322  
   323  	// this final Get will fail to fail the checksum verification, the above
   324  	// callback will update the state with the correct version, and Get should
   325  	// retry automatically.
   326  	if _, err := client1.Get(); err != nil {
   327  		t.Fatal(err)
   328  	}
   329  }