github.com/terramate-io/tf@v0.0.0-20230830114523-fce866b4dfcd/backend/remote-state/s3/backend_state.go (about)

     1  // Copyright (c) HashiCorp, Inc.
     2  // SPDX-License-Identifier: MPL-2.0
     3  
     4  package s3
     5  
     6  import (
     7  	"errors"
     8  	"fmt"
     9  	"path"
    10  	"sort"
    11  	"strings"
    12  
    13  	"github.com/aws/aws-sdk-go/aws"
    14  	"github.com/aws/aws-sdk-go/aws/awserr"
    15  	"github.com/aws/aws-sdk-go/service/s3"
    16  
    17  	"github.com/terramate-io/tf/backend"
    18  	"github.com/terramate-io/tf/states"
    19  	"github.com/terramate-io/tf/states/remote"
    20  	"github.com/terramate-io/tf/states/statemgr"
    21  )
    22  
    23  func (b *Backend) Workspaces() ([]string, error) {
    24  	const maxKeys = 1000
    25  
    26  	prefix := ""
    27  
    28  	if b.workspaceKeyPrefix != "" {
    29  		prefix = b.workspaceKeyPrefix + "/"
    30  	}
    31  
    32  	params := &s3.ListObjectsInput{
    33  		Bucket:  &b.bucketName,
    34  		Prefix:  aws.String(prefix),
    35  		MaxKeys: aws.Int64(maxKeys),
    36  	}
    37  
    38  	wss := []string{backend.DefaultStateName}
    39  	err := b.s3Client.ListObjectsPages(params, func(page *s3.ListObjectsOutput, lastPage bool) bool {
    40  		for _, obj := range page.Contents {
    41  			ws := b.keyEnv(*obj.Key)
    42  			if ws != "" {
    43  				wss = append(wss, ws)
    44  			}
    45  		}
    46  		return !lastPage
    47  	})
    48  
    49  	if awsErr, ok := err.(awserr.Error); ok && awsErr.Code() == s3.ErrCodeNoSuchBucket {
    50  		return nil, fmt.Errorf(errS3NoSuchBucket, err)
    51  	}
    52  
    53  	sort.Strings(wss[1:])
    54  	return wss, nil
    55  }
    56  
    57  func (b *Backend) keyEnv(key string) string {
    58  	prefix := b.workspaceKeyPrefix
    59  
    60  	if prefix == "" {
    61  		parts := strings.SplitN(key, "/", 2)
    62  		if len(parts) > 1 && parts[1] == b.keyName {
    63  			return parts[0]
    64  		} else {
    65  			return ""
    66  		}
    67  	}
    68  
    69  	// add a slash to treat this as a directory
    70  	prefix += "/"
    71  
    72  	parts := strings.SplitAfterN(key, prefix, 2)
    73  	if len(parts) < 2 {
    74  		return ""
    75  	}
    76  
    77  	// shouldn't happen since we listed by prefix
    78  	if parts[0] != prefix {
    79  		return ""
    80  	}
    81  
    82  	parts = strings.SplitN(parts[1], "/", 2)
    83  
    84  	if len(parts) < 2 {
    85  		return ""
    86  	}
    87  
    88  	// not our key, so don't include it in our listing
    89  	if parts[1] != b.keyName {
    90  		return ""
    91  	}
    92  
    93  	return parts[0]
    94  }
    95  
    96  func (b *Backend) DeleteWorkspace(name string, _ bool) error {
    97  	if name == backend.DefaultStateName || name == "" {
    98  		return fmt.Errorf("can't delete default state")
    99  	}
   100  
   101  	client, err := b.remoteClient(name)
   102  	if err != nil {
   103  		return err
   104  	}
   105  
   106  	return client.Delete()
   107  }
   108  
   109  // get a remote client configured for this state
   110  func (b *Backend) remoteClient(name string) (*RemoteClient, error) {
   111  	if name == "" {
   112  		return nil, errors.New("missing state name")
   113  	}
   114  
   115  	client := &RemoteClient{
   116  		s3Client:              b.s3Client,
   117  		dynClient:             b.dynClient,
   118  		bucketName:            b.bucketName,
   119  		path:                  b.path(name),
   120  		serverSideEncryption:  b.serverSideEncryption,
   121  		customerEncryptionKey: b.customerEncryptionKey,
   122  		acl:                   b.acl,
   123  		kmsKeyID:              b.kmsKeyID,
   124  		ddbTable:              b.ddbTable,
   125  	}
   126  
   127  	return client, nil
   128  }
   129  
   130  func (b *Backend) StateMgr(name string) (statemgr.Full, error) {
   131  	client, err := b.remoteClient(name)
   132  	if err != nil {
   133  		return nil, err
   134  	}
   135  
   136  	stateMgr := &remote.State{Client: client}
   137  	// Check to see if this state already exists.
   138  	// If we're trying to force-unlock a state, we can't take the lock before
   139  	// fetching the state. If the state doesn't exist, we have to assume this
   140  	// is a normal create operation, and take the lock at that point.
   141  	//
   142  	// If we need to force-unlock, but for some reason the state no longer
   143  	// exists, the user will have to use aws tools to manually fix the
   144  	// situation.
   145  	existing, err := b.Workspaces()
   146  	if err != nil {
   147  		return nil, err
   148  	}
   149  
   150  	exists := false
   151  	for _, s := range existing {
   152  		if s == name {
   153  			exists = true
   154  			break
   155  		}
   156  	}
   157  
   158  	// We need to create the object so it's listed by States.
   159  	if !exists {
   160  		// take a lock on this state while we write it
   161  		lockInfo := statemgr.NewLockInfo()
   162  		lockInfo.Operation = "init"
   163  		lockId, err := client.Lock(lockInfo)
   164  		if err != nil {
   165  			return nil, fmt.Errorf("failed to lock s3 state: %s", err)
   166  		}
   167  
   168  		// Local helper function so we can call it multiple places
   169  		lockUnlock := func(parent error) error {
   170  			if err := stateMgr.Unlock(lockId); err != nil {
   171  				return fmt.Errorf(strings.TrimSpace(errStateUnlock), lockId, err)
   172  			}
   173  			return parent
   174  		}
   175  
   176  		// Grab the value
   177  		// This is to ensure that no one beat us to writing a state between
   178  		// the `exists` check and taking the lock.
   179  		if err := stateMgr.RefreshState(); err != nil {
   180  			err = lockUnlock(err)
   181  			return nil, err
   182  		}
   183  
   184  		// If we have no state, we have to create an empty state
   185  		if v := stateMgr.State(); v == nil {
   186  			if err := stateMgr.WriteState(states.NewState()); err != nil {
   187  				err = lockUnlock(err)
   188  				return nil, err
   189  			}
   190  			if err := stateMgr.PersistState(nil); err != nil {
   191  				err = lockUnlock(err)
   192  				return nil, err
   193  			}
   194  		}
   195  
   196  		// Unlock, the state should now be initialized
   197  		if err := lockUnlock(nil); err != nil {
   198  			return nil, err
   199  		}
   200  
   201  	}
   202  
   203  	return stateMgr, nil
   204  }
   205  
   206  func (b *Backend) path(name string) string {
   207  	if name == backend.DefaultStateName {
   208  		return b.keyName
   209  	}
   210  
   211  	return path.Join(b.workspaceKeyPrefix, name, b.keyName)
   212  }
   213  
   214  const errStateUnlock = `
   215  Error unlocking S3 state. Lock ID: %s
   216  
   217  Error: %s
   218  
   219  You may have to force-unlock this state in order to use it again.
   220  `