github.com/terramate-io/tf@v0.0.0-20230830114523-fce866b4dfcd/backend/remote-state/s3/client.go (about)

     1  // Copyright (c) HashiCorp, Inc.
     2  // SPDX-License-Identifier: MPL-2.0
     3  
     4  package s3
     5  
     6  import (
     7  	"bytes"
     8  	"crypto/md5"
     9  	"encoding/base64"
    10  	"encoding/hex"
    11  	"encoding/json"
    12  	"errors"
    13  	"fmt"
    14  	"io"
    15  	"log"
    16  	"time"
    17  
    18  	"github.com/aws/aws-sdk-go/aws"
    19  	"github.com/aws/aws-sdk-go/aws/awserr"
    20  	"github.com/aws/aws-sdk-go/service/dynamodb"
    21  	"github.com/aws/aws-sdk-go/service/s3"
    22  	multierror "github.com/hashicorp/go-multierror"
    23  	uuid "github.com/hashicorp/go-uuid"
    24  	"github.com/terramate-io/tf/states/remote"
    25  	"github.com/terramate-io/tf/states/statemgr"
    26  )
    27  
    28  // Store the last saved serial in dynamo with this suffix for consistency checks.
    29  const (
    30  	s3EncryptionAlgorithm  = "AES256"
    31  	stateIDSuffix          = "-md5"
    32  	s3ErrCodeInternalError = "InternalError"
    33  )
    34  
// RemoteClient reads, writes and locks a single state object stored in S3,
// optionally using a DynamoDB table for locking and digest tracking.
//
// Fields:
//   - s3Client: client used for GetObject, PutObject and DeleteObject.
//   - dynClient: client used for the lock item and the state digest item.
//   - bucketName, path: bucket and key of the state object; joined with "/"
//     they also form the DynamoDB lock ID.
//   - serverSideEncryption: when true, Put requests encryption; reads send
//     the customer key headers only if customerEncryptionKey is also set.
//   - customerEncryptionKey: raw key sent with customer-key encrypted
//     requests.
//   - acl: if non-empty, applied as the object ACL on Put.
//   - kmsKeyID: if non-empty (and encryption is enabled), Put uses "aws:kms"
//     with this key ID.
//   - ddbTable: DynamoDB table name; when empty, locking and digest
//     bookkeeping are skipped.
type RemoteClient struct {
	s3Client              *s3.S3
	dynClient             *dynamodb.DynamoDB
	bucketName            string
	path                  string
	serverSideEncryption  bool
	customerEncryptionKey []byte
	acl                   string
	kmsKeyID              string
	ddbTable              string
}
    46  
    47  var (
    48  	// The amount of time we will retry a state waiting for it to match the
    49  	// expected checksum.
    50  	consistencyRetryTimeout = 10 * time.Second
    51  
    52  	// delay when polling the state
    53  	consistencyRetryPollInterval = 2 * time.Second
    54  )
    55  
    56  // test hook called when checksums don't match
    57  var testChecksumHook func()
    58  
    59  func (c *RemoteClient) Get() (payload *remote.Payload, err error) {
    60  	deadline := time.Now().Add(consistencyRetryTimeout)
    61  
    62  	// If we have a checksum, and the returned payload doesn't match, we retry
    63  	// up until deadline.
    64  	for {
    65  		payload, err = c.get()
    66  		if err != nil {
    67  			return nil, err
    68  		}
    69  
    70  		// If the remote state was manually removed the payload will be nil,
    71  		// but if there's still a digest entry for that state we will still try
    72  		// to compare the MD5 below.
    73  		var digest []byte
    74  		if payload != nil {
    75  			digest = payload.MD5
    76  		}
    77  
    78  		// verify that this state is what we expect
    79  		if expected, err := c.getMD5(); err != nil {
    80  			log.Printf("[WARN] failed to fetch state md5: %s", err)
    81  		} else if len(expected) > 0 && !bytes.Equal(expected, digest) {
    82  			log.Printf("[WARN] state md5 mismatch: expected '%x', got '%x'", expected, digest)
    83  
    84  			if testChecksumHook != nil {
    85  				testChecksumHook()
    86  			}
    87  
    88  			if time.Now().Before(deadline) {
    89  				time.Sleep(consistencyRetryPollInterval)
    90  				log.Println("[INFO] retrying S3 RemoteClient.Get...")
    91  				continue
    92  			}
    93  
    94  			return nil, fmt.Errorf(errBadChecksumFmt, digest)
    95  		}
    96  
    97  		break
    98  	}
    99  
   100  	return payload, err
   101  }
   102  
   103  func (c *RemoteClient) get() (*remote.Payload, error) {
   104  	var output *s3.GetObjectOutput
   105  	var err error
   106  
   107  	input := &s3.GetObjectInput{
   108  		Bucket: &c.bucketName,
   109  		Key:    &c.path,
   110  	}
   111  
   112  	if c.serverSideEncryption && c.customerEncryptionKey != nil {
   113  		input.SetSSECustomerKey(string(c.customerEncryptionKey))
   114  		input.SetSSECustomerAlgorithm(s3EncryptionAlgorithm)
   115  		input.SetSSECustomerKeyMD5(c.getSSECustomerKeyMD5())
   116  	}
   117  
   118  	output, err = c.s3Client.GetObject(input)
   119  
   120  	if err != nil {
   121  		if awserr, ok := err.(awserr.Error); ok {
   122  			switch awserr.Code() {
   123  			case s3.ErrCodeNoSuchBucket:
   124  				return nil, fmt.Errorf(errS3NoSuchBucket, err)
   125  			case s3.ErrCodeNoSuchKey:
   126  				return nil, nil
   127  			}
   128  		}
   129  		return nil, err
   130  	}
   131  
   132  	defer output.Body.Close()
   133  
   134  	buf := bytes.NewBuffer(nil)
   135  	if _, err := io.Copy(buf, output.Body); err != nil {
   136  		return nil, fmt.Errorf("Failed to read remote state: %s", err)
   137  	}
   138  
   139  	sum := md5.Sum(buf.Bytes())
   140  	payload := &remote.Payload{
   141  		Data: buf.Bytes(),
   142  		MD5:  sum[:],
   143  	}
   144  
   145  	// If there was no data, then return nil
   146  	if len(payload.Data) == 0 {
   147  		return nil, nil
   148  	}
   149  
   150  	return payload, nil
   151  }
   152  
   153  func (c *RemoteClient) Put(data []byte) error {
   154  	contentType := "application/json"
   155  	contentLength := int64(len(data))
   156  
   157  	i := &s3.PutObjectInput{
   158  		ContentType:   &contentType,
   159  		ContentLength: &contentLength,
   160  		Body:          bytes.NewReader(data),
   161  		Bucket:        &c.bucketName,
   162  		Key:           &c.path,
   163  	}
   164  
   165  	if c.serverSideEncryption {
   166  		if c.kmsKeyID != "" {
   167  			i.SSEKMSKeyId = &c.kmsKeyID
   168  			i.ServerSideEncryption = aws.String("aws:kms")
   169  		} else if c.customerEncryptionKey != nil {
   170  			i.SetSSECustomerKey(string(c.customerEncryptionKey))
   171  			i.SetSSECustomerAlgorithm(s3EncryptionAlgorithm)
   172  			i.SetSSECustomerKeyMD5(c.getSSECustomerKeyMD5())
   173  		} else {
   174  			i.ServerSideEncryption = aws.String(s3EncryptionAlgorithm)
   175  		}
   176  	}
   177  
   178  	if c.acl != "" {
   179  		i.ACL = aws.String(c.acl)
   180  	}
   181  
   182  	log.Printf("[DEBUG] Uploading remote state to S3: %#v", i)
   183  
   184  	_, err := c.s3Client.PutObject(i)
   185  	if err != nil {
   186  		return fmt.Errorf("failed to upload state: %s", err)
   187  	}
   188  
   189  	sum := md5.Sum(data)
   190  	if err := c.putMD5(sum[:]); err != nil {
   191  		// if this errors out, we unfortunately have to error out altogether,
   192  		// since the next Get will inevitably fail.
   193  		return fmt.Errorf("failed to store state MD5: %s", err)
   194  
   195  	}
   196  
   197  	return nil
   198  }
   199  
   200  func (c *RemoteClient) Delete() error {
   201  	_, err := c.s3Client.DeleteObject(&s3.DeleteObjectInput{
   202  		Bucket: &c.bucketName,
   203  		Key:    &c.path,
   204  	})
   205  
   206  	if err != nil {
   207  		return err
   208  	}
   209  
   210  	if err := c.deleteMD5(); err != nil {
   211  		log.Printf("error deleting state md5: %s", err)
   212  	}
   213  
   214  	return nil
   215  }
   216  
   217  func (c *RemoteClient) Lock(info *statemgr.LockInfo) (string, error) {
   218  	if c.ddbTable == "" {
   219  		return "", nil
   220  	}
   221  
   222  	info.Path = c.lockPath()
   223  
   224  	if info.ID == "" {
   225  		lockID, err := uuid.GenerateUUID()
   226  		if err != nil {
   227  			return "", err
   228  		}
   229  
   230  		info.ID = lockID
   231  	}
   232  
   233  	putParams := &dynamodb.PutItemInput{
   234  		Item: map[string]*dynamodb.AttributeValue{
   235  			"LockID": {S: aws.String(c.lockPath())},
   236  			"Info":   {S: aws.String(string(info.Marshal()))},
   237  		},
   238  		TableName:           aws.String(c.ddbTable),
   239  		ConditionExpression: aws.String("attribute_not_exists(LockID)"),
   240  	}
   241  	_, err := c.dynClient.PutItem(putParams)
   242  
   243  	if err != nil {
   244  		lockInfo, infoErr := c.getLockInfo()
   245  		if infoErr != nil {
   246  			err = multierror.Append(err, infoErr)
   247  		}
   248  
   249  		lockErr := &statemgr.LockError{
   250  			Err:  err,
   251  			Info: lockInfo,
   252  		}
   253  		return "", lockErr
   254  	}
   255  
   256  	return info.ID, nil
   257  }
   258  
   259  func (c *RemoteClient) getMD5() ([]byte, error) {
   260  	if c.ddbTable == "" {
   261  		return nil, nil
   262  	}
   263  
   264  	getParams := &dynamodb.GetItemInput{
   265  		Key: map[string]*dynamodb.AttributeValue{
   266  			"LockID": {S: aws.String(c.lockPath() + stateIDSuffix)},
   267  		},
   268  		ProjectionExpression: aws.String("LockID, Digest"),
   269  		TableName:            aws.String(c.ddbTable),
   270  		ConsistentRead:       aws.Bool(true),
   271  	}
   272  
   273  	resp, err := c.dynClient.GetItem(getParams)
   274  	if err != nil {
   275  		return nil, err
   276  	}
   277  
   278  	var val string
   279  	if v, ok := resp.Item["Digest"]; ok && v.S != nil {
   280  		val = *v.S
   281  	}
   282  
   283  	sum, err := hex.DecodeString(val)
   284  	if err != nil || len(sum) != md5.Size {
   285  		return nil, errors.New("invalid md5")
   286  	}
   287  
   288  	return sum, nil
   289  }
   290  
   291  // store the hash of the state so that clients can check for stale state files.
   292  func (c *RemoteClient) putMD5(sum []byte) error {
   293  	if c.ddbTable == "" {
   294  		return nil
   295  	}
   296  
   297  	if len(sum) != md5.Size {
   298  		return errors.New("invalid payload md5")
   299  	}
   300  
   301  	putParams := &dynamodb.PutItemInput{
   302  		Item: map[string]*dynamodb.AttributeValue{
   303  			"LockID": {S: aws.String(c.lockPath() + stateIDSuffix)},
   304  			"Digest": {S: aws.String(hex.EncodeToString(sum))},
   305  		},
   306  		TableName: aws.String(c.ddbTable),
   307  	}
   308  	_, err := c.dynClient.PutItem(putParams)
   309  	if err != nil {
   310  		log.Printf("[WARN] failed to record state serial in dynamodb: %s", err)
   311  	}
   312  
   313  	return nil
   314  }
   315  
   316  // remove the hash value for a deleted state
   317  func (c *RemoteClient) deleteMD5() error {
   318  	if c.ddbTable == "" {
   319  		return nil
   320  	}
   321  
   322  	params := &dynamodb.DeleteItemInput{
   323  		Key: map[string]*dynamodb.AttributeValue{
   324  			"LockID": {S: aws.String(c.lockPath() + stateIDSuffix)},
   325  		},
   326  		TableName: aws.String(c.ddbTable),
   327  	}
   328  	if _, err := c.dynClient.DeleteItem(params); err != nil {
   329  		return err
   330  	}
   331  	return nil
   332  }
   333  
   334  func (c *RemoteClient) getLockInfo() (*statemgr.LockInfo, error) {
   335  	getParams := &dynamodb.GetItemInput{
   336  		Key: map[string]*dynamodb.AttributeValue{
   337  			"LockID": {S: aws.String(c.lockPath())},
   338  		},
   339  		ProjectionExpression: aws.String("LockID, Info"),
   340  		TableName:            aws.String(c.ddbTable),
   341  		ConsistentRead:       aws.Bool(true),
   342  	}
   343  
   344  	resp, err := c.dynClient.GetItem(getParams)
   345  	if err != nil {
   346  		return nil, err
   347  	}
   348  
   349  	var infoData string
   350  	if v, ok := resp.Item["Info"]; ok && v.S != nil {
   351  		infoData = *v.S
   352  	}
   353  
   354  	lockInfo := &statemgr.LockInfo{}
   355  	err = json.Unmarshal([]byte(infoData), lockInfo)
   356  	if err != nil {
   357  		return nil, err
   358  	}
   359  
   360  	return lockInfo, nil
   361  }
   362  
   363  func (c *RemoteClient) Unlock(id string) error {
   364  	if c.ddbTable == "" {
   365  		return nil
   366  	}
   367  
   368  	lockErr := &statemgr.LockError{}
   369  
   370  	// TODO: store the path and lock ID in separate fields, and have proper
   371  	// projection expression only delete the lock if both match, rather than
   372  	// checking the ID from the info field first.
   373  	lockInfo, err := c.getLockInfo()
   374  	if err != nil {
   375  		lockErr.Err = fmt.Errorf("failed to retrieve lock info: %s", err)
   376  		return lockErr
   377  	}
   378  	lockErr.Info = lockInfo
   379  
   380  	if lockInfo.ID != id {
   381  		lockErr.Err = fmt.Errorf("lock id %q does not match existing lock", id)
   382  		return lockErr
   383  	}
   384  
   385  	params := &dynamodb.DeleteItemInput{
   386  		Key: map[string]*dynamodb.AttributeValue{
   387  			"LockID": {S: aws.String(c.lockPath())},
   388  		},
   389  		TableName: aws.String(c.ddbTable),
   390  	}
   391  	_, err = c.dynClient.DeleteItem(params)
   392  
   393  	if err != nil {
   394  		lockErr.Err = err
   395  		return lockErr
   396  	}
   397  	return nil
   398  }
   399  
   400  func (c *RemoteClient) lockPath() string {
   401  	return fmt.Sprintf("%s/%s", c.bucketName, c.path)
   402  }
   403  
   404  func (c *RemoteClient) getSSECustomerKeyMD5() string {
   405  	b := md5.Sum(c.customerEncryptionKey)
   406  	return base64.StdEncoding.EncodeToString(b[:])
   407  }
   408  
// errBadChecksumFmt is the error message format used by Get when the
// downloaded state still does not match the recorded digest after retrying.
// Its single %x verb receives the MD5 digest of the data that was actually
// downloaded.
const errBadChecksumFmt = `state data in S3 does not have the expected content.

This may be caused by unusually long delays in S3 processing a previous state
update.  Please wait for a minute or two and try again. If this problem
persists, and neither S3 nor DynamoDB are experiencing an outage, you may need
to manually verify the remote state and update the Digest value stored in the
DynamoDB table to the following value: %x
`
   417  
// errS3NoSuchBucket is the error message format used when S3 reports that the
// bucket does not exist. Its single %s verb receives the underlying error.
const errS3NoSuchBucket = `S3 bucket does not exist.

The referenced S3 bucket must have been previously created. If the S3 bucket
was created within the last minute, please wait for a minute or two and try
again.

Error: %s
`