github.com/opentofu/opentofu@v1.7.1/internal/backend/remote-state/s3/client.go (about)

     1  // Copyright (c) The OpenTofu Authors
     2  // SPDX-License-Identifier: MPL-2.0
     3  // Copyright (c) 2023 HashiCorp, Inc.
     4  // SPDX-License-Identifier: MPL-2.0
     5  
     6  package s3
     7  
     8  import (
     9  	"bytes"
    10  	"context"
    11  	"crypto/md5"
    12  	"crypto/sha256"
    13  	"encoding/base64"
    14  	"encoding/hex"
    15  	"encoding/json"
    16  	"errors"
    17  	"fmt"
    18  	"io"
    19  	"log"
    20  	"time"
    21  
    22  	"github.com/aws/aws-sdk-go-v2/aws"
    23  	"github.com/aws/aws-sdk-go-v2/service/dynamodb"
    24  	dtypes "github.com/aws/aws-sdk-go-v2/service/dynamodb/types"
    25  	"github.com/aws/aws-sdk-go-v2/service/s3"
    26  	types "github.com/aws/aws-sdk-go-v2/service/s3/types"
    27  	multierror "github.com/hashicorp/go-multierror"
    28  	uuid "github.com/hashicorp/go-uuid"
    29  
    30  	"github.com/opentofu/opentofu/internal/states/remote"
    31  	"github.com/opentofu/opentofu/internal/states/statemgr"
    32  )
    33  
// Constants shared by the S3 remote state client.
//
//   - s3EncryptionAlgorithm is the algorithm name sent to S3, both as the
//     SSE-C customer algorithm and as the default server-side encryption mode.
//   - stateIDSuffix is appended to the lock path to form the DynamoDB LockID of
//     the item that stores the MD5 digest of the last saved state, which is
//     used for consistency checks.
//   - s3ErrCodeInternalError is not referenced in this part of the file;
//     presumably it is the S3 error code matched by retry logic elsewhere
//     (TODO confirm).
const (
	s3EncryptionAlgorithm  = "AES256"
	stateIDSuffix          = "-md5"
	s3ErrCodeInternalError = "InternalError"
)
    40  
// RemoteClient reads, writes and deletes a single state object in S3 and,
// when a DynamoDB table is configured, uses that table for state locking and
// for storing an MD5 digest of the last written state.
type RemoteClient struct {
	// s3Client performs the HeadObject/GetObject/PutObject/DeleteObject calls.
	s3Client              *s3.Client
	// dynClient performs the DynamoDB calls for locks and digests.
	dynClient             *dynamodb.Client
	// bucketName and path identify the state object (bucket and key); joined
	// with "/" they also form the DynamoDB lock path.
	bucketName            string
	path                  string
	// serverSideEncryption enables the encryption parameters on requests.
	serverSideEncryption  bool
	// customerEncryptionKey is the raw SSE-C key; it is base64-encoded before
	// being sent. Only used when serverSideEncryption is true.
	customerEncryptionKey []byte
	// acl, when non-empty, is applied as the canned ACL on upload.
	acl                   string
	// kmsKeyID, when non-empty and serverSideEncryption is true, selects
	// KMS-based encryption with this key on upload.
	kmsKeyID              string
	// ddbTable is the DynamoDB table name; when empty, locking and digest
	// checks are skipped.
	ddbTable              string

	// skipS3Checksum disables the SHA-256 checksum that Put otherwise attaches.
	skipS3Checksum bool
}
    54  
// Tunables for the checksum consistency loop in Get. They are variables rather
// than constants, presumably so tests can shorten them (TODO confirm).
var (
	// The amount of time we will retry a state waiting for it to match the
	// expected checksum.
	consistencyRetryTimeout = 10 * time.Second

	// Delay between successive polls of the state while waiting for the
	// checksum to match.
	consistencyRetryPollInterval = 2 * time.Second
)
    63  
// testChecksumHook, when non-nil, is invoked by Get each time the digest of
// the fetched state does not match the expected digest. It is nil by default
// and exists as a test hook.
var testChecksumHook func()
    66  
    67  func (c *RemoteClient) Get() (payload *remote.Payload, err error) {
    68  	ctx := context.TODO()
    69  	deadline := time.Now().Add(consistencyRetryTimeout)
    70  
    71  	// If we have a checksum, and the returned payload doesn't match, we retry
    72  	// up until deadline.
    73  	for {
    74  		payload, err = c.get(ctx)
    75  		if err != nil {
    76  			return nil, err
    77  		}
    78  
    79  		// If the remote state was manually removed the payload will be nil,
    80  		// but if there's still a digest entry for that state we will still try
    81  		// to compare the MD5 below.
    82  		var digest []byte
    83  		if payload != nil {
    84  			digest = payload.MD5
    85  		}
    86  
    87  		// verify that this state is what we expect
    88  		if expected, err := c.getMD5(ctx); err != nil {
    89  			log.Printf("[WARN] failed to fetch state md5: %s", err)
    90  		} else if len(expected) > 0 && !bytes.Equal(expected, digest) {
    91  			log.Printf("[WARN] state md5 mismatch: expected '%x', got '%x'", expected, digest)
    92  
    93  			if testChecksumHook != nil {
    94  				testChecksumHook()
    95  			}
    96  
    97  			if time.Now().Before(deadline) {
    98  				time.Sleep(consistencyRetryPollInterval)
    99  				log.Println("[INFO] retrying S3 RemoteClient.Get...")
   100  				continue
   101  			}
   102  
   103  			return nil, fmt.Errorf(errBadChecksumFmt, digest)
   104  		}
   105  
   106  		break
   107  	}
   108  
   109  	return payload, err
   110  }
   111  
   112  func (c *RemoteClient) get(ctx context.Context) (*remote.Payload, error) {
   113  	var output *s3.GetObjectOutput
   114  	var err error
   115  
   116  	ctx, _ = attachLoggerToContext(ctx)
   117  
   118  	inputHead := &s3.HeadObjectInput{
   119  		Bucket: &c.bucketName,
   120  		Key:    &c.path,
   121  	}
   122  
   123  	if c.serverSideEncryption && c.customerEncryptionKey != nil {
   124  		inputHead.SSECustomerKey = aws.String(base64.StdEncoding.EncodeToString(c.customerEncryptionKey))
   125  		inputHead.SSECustomerAlgorithm = aws.String(s3EncryptionAlgorithm)
   126  		inputHead.SSECustomerKeyMD5 = aws.String(c.getSSECustomerKeyMD5())
   127  	}
   128  
   129  	// Head works around some s3 compatible backends not handling missing GetObject requests correctly (ex: minio Get returns Missing Bucket)
   130  	_, err = c.s3Client.HeadObject(ctx, inputHead)
   131  	if err != nil {
   132  		var nb *types.NoSuchBucket
   133  		if errors.As(err, &nb) {
   134  			return nil, fmt.Errorf(errS3NoSuchBucket, err)
   135  		}
   136  
   137  		var nk *types.NotFound
   138  		if errors.As(err, &nk) {
   139  			return nil, nil
   140  		}
   141  
   142  		return nil, err
   143  	}
   144  
   145  	input := &s3.GetObjectInput{
   146  		Bucket: &c.bucketName,
   147  		Key:    &c.path,
   148  	}
   149  
   150  	if c.serverSideEncryption && c.customerEncryptionKey != nil {
   151  		input.SSECustomerKey = aws.String(base64.StdEncoding.EncodeToString(c.customerEncryptionKey))
   152  		input.SSECustomerAlgorithm = aws.String(s3EncryptionAlgorithm)
   153  		input.SSECustomerKeyMD5 = aws.String(c.getSSECustomerKeyMD5())
   154  	}
   155  
   156  	output, err = c.s3Client.GetObject(ctx, input)
   157  	if err != nil {
   158  		var nb *types.NoSuchBucket
   159  		if errors.As(err, &nb) {
   160  			return nil, fmt.Errorf(errS3NoSuchBucket, err)
   161  		}
   162  
   163  		var nk *types.NoSuchKey
   164  		if errors.As(err, &nk) {
   165  			return nil, nil
   166  		}
   167  
   168  		return nil, err
   169  	}
   170  
   171  	defer output.Body.Close()
   172  
   173  	buf := bytes.NewBuffer(nil)
   174  	if _, err := io.Copy(buf, output.Body); err != nil {
   175  		return nil, fmt.Errorf("Failed to read remote state: %w", err)
   176  	}
   177  
   178  	sum := md5.Sum(buf.Bytes())
   179  	payload := &remote.Payload{
   180  		Data: buf.Bytes(),
   181  		MD5:  sum[:],
   182  	}
   183  
   184  	// If there was no data, then return nil
   185  	if len(payload.Data) == 0 {
   186  		return nil, nil
   187  	}
   188  
   189  	return payload, nil
   190  }
   191  
   192  func (c *RemoteClient) Put(data []byte) error {
   193  	contentType := "application/json"
   194  	contentLength := int64(len(data))
   195  
   196  	i := &s3.PutObjectInput{
   197  		ContentType:   &contentType,
   198  		ContentLength: aws.Int64(contentLength),
   199  		Body:          bytes.NewReader(data),
   200  		Bucket:        &c.bucketName,
   201  		Key:           &c.path,
   202  	}
   203  
   204  	if !c.skipS3Checksum {
   205  		i.ChecksumAlgorithm = types.ChecksumAlgorithmSha256
   206  
   207  		// There is a conflict in the aws-go-sdk-v2 that prevents it from working with many s3 compatible services
   208  		// Since we can pre-compute the hash here, we can work around it.
   209  		// ref: https://github.com/aws/aws-sdk-go-v2/issues/1689
   210  		algo := sha256.New()
   211  		algo.Write(data)
   212  		sum64str := base64.StdEncoding.EncodeToString(algo.Sum(nil))
   213  		i.ChecksumSHA256 = &sum64str
   214  	}
   215  
   216  	if c.serverSideEncryption {
   217  		if c.kmsKeyID != "" {
   218  			i.SSEKMSKeyId = &c.kmsKeyID
   219  			i.ServerSideEncryption = types.ServerSideEncryptionAwsKms
   220  		} else if c.customerEncryptionKey != nil {
   221  			i.SSECustomerKey = aws.String(base64.StdEncoding.EncodeToString(c.customerEncryptionKey))
   222  			i.SSECustomerAlgorithm = aws.String(string(s3EncryptionAlgorithm))
   223  			i.SSECustomerKeyMD5 = aws.String(c.getSSECustomerKeyMD5())
   224  		} else {
   225  			i.ServerSideEncryption = s3EncryptionAlgorithm
   226  		}
   227  	}
   228  
   229  	if c.acl != "" {
   230  		i.ACL = types.ObjectCannedACL(c.acl)
   231  	}
   232  
   233  	log.Printf("[DEBUG] Uploading remote state to S3: %#v", i)
   234  
   235  	ctx := context.TODO()
   236  	ctx, _ = attachLoggerToContext(ctx)
   237  
   238  	_, err := c.s3Client.PutObject(ctx, i)
   239  	if err != nil {
   240  		return fmt.Errorf("failed to upload state: %w", err)
   241  	}
   242  
   243  	sum := md5.Sum(data)
   244  	if err := c.putMD5(ctx, sum[:]); err != nil {
   245  		// if this errors out, we unfortunately have to error out altogether,
   246  		// since the next Get will inevitably fail.
   247  		return fmt.Errorf("failed to store state MD5: %w", err)
   248  
   249  	}
   250  
   251  	return nil
   252  }
   253  
   254  func (c *RemoteClient) Delete() error {
   255  	ctx := context.TODO()
   256  	ctx, _ = attachLoggerToContext(ctx)
   257  
   258  	_, err := c.s3Client.DeleteObject(ctx, &s3.DeleteObjectInput{
   259  		Bucket: &c.bucketName,
   260  		Key:    &c.path,
   261  	})
   262  
   263  	if err != nil {
   264  		return err
   265  	}
   266  
   267  	if err := c.deleteMD5(ctx); err != nil {
   268  		log.Printf("error deleting state md5: %s", err)
   269  	}
   270  
   271  	return nil
   272  }
   273  
   274  func (c *RemoteClient) Lock(info *statemgr.LockInfo) (string, error) {
   275  	if c.ddbTable == "" {
   276  		return "", nil
   277  	}
   278  
   279  	info.Path = c.lockPath()
   280  
   281  	if info.ID == "" {
   282  		lockID, err := uuid.GenerateUUID()
   283  		if err != nil {
   284  			return "", err
   285  		}
   286  
   287  		info.ID = lockID
   288  	}
   289  
   290  	putParams := &dynamodb.PutItemInput{
   291  		Item: map[string]dtypes.AttributeValue{
   292  			"LockID": &dtypes.AttributeValueMemberS{Value: c.lockPath()},
   293  			"Info":   &dtypes.AttributeValueMemberS{Value: string(info.Marshal())},
   294  		},
   295  		TableName:           aws.String(c.ddbTable),
   296  		ConditionExpression: aws.String("attribute_not_exists(LockID)"),
   297  	}
   298  
   299  	ctx := context.TODO()
   300  	_, err := c.dynClient.PutItem(ctx, putParams)
   301  	if err != nil {
   302  		lockInfo, infoErr := c.getLockInfo(ctx)
   303  		if infoErr != nil {
   304  			err = multierror.Append(err, infoErr)
   305  		}
   306  
   307  		lockErr := &statemgr.LockError{
   308  			Err:  err,
   309  			Info: lockInfo,
   310  		}
   311  		return "", lockErr
   312  	}
   313  
   314  	return info.ID, nil
   315  }
   316  
   317  func (c *RemoteClient) getMD5(ctx context.Context) ([]byte, error) {
   318  	if c.ddbTable == "" {
   319  		return nil, nil
   320  	}
   321  
   322  	getParams := &dynamodb.GetItemInput{
   323  		Key: map[string]dtypes.AttributeValue{
   324  			"LockID": &dtypes.AttributeValueMemberS{Value: c.lockPath() + stateIDSuffix},
   325  		},
   326  		ProjectionExpression: aws.String("LockID, Digest"),
   327  		TableName:            aws.String(c.ddbTable),
   328  		ConsistentRead:       aws.Bool(true),
   329  	}
   330  
   331  	resp, err := c.dynClient.GetItem(ctx, getParams)
   332  	if err != nil {
   333  		return nil, err
   334  	}
   335  
   336  	var val string
   337  	if v, ok := resp.Item["Digest"]; ok {
   338  		if v, ok := v.(*dtypes.AttributeValueMemberS); ok {
   339  			val = v.Value
   340  		}
   341  	}
   342  
   343  	sum, err := hex.DecodeString(val)
   344  	if err != nil || len(sum) != md5.Size {
   345  		return nil, errors.New("invalid md5")
   346  	}
   347  
   348  	return sum, nil
   349  }
   350  
   351  // store the hash of the state so that clients can check for stale state files.
   352  func (c *RemoteClient) putMD5(ctx context.Context, sum []byte) error {
   353  	if c.ddbTable == "" {
   354  		return nil
   355  	}
   356  
   357  	if len(sum) != md5.Size {
   358  		return errors.New("invalid payload md5")
   359  	}
   360  
   361  	putParams := &dynamodb.PutItemInput{
   362  		Item: map[string]dtypes.AttributeValue{
   363  			"LockID": &dtypes.AttributeValueMemberS{Value: c.lockPath() + stateIDSuffix},
   364  			"Digest": &dtypes.AttributeValueMemberS{Value: hex.EncodeToString(sum)},
   365  		},
   366  		TableName: aws.String(c.ddbTable),
   367  	}
   368  	_, err := c.dynClient.PutItem(ctx, putParams)
   369  	if err != nil {
   370  		log.Printf("[WARN] failed to record state serial in dynamodb: %s", err)
   371  	}
   372  
   373  	return nil
   374  }
   375  
   376  // remove the hash value for a deleted state
   377  func (c *RemoteClient) deleteMD5(ctx context.Context) error {
   378  	if c.ddbTable == "" {
   379  		return nil
   380  	}
   381  
   382  	params := &dynamodb.DeleteItemInput{
   383  		Key: map[string]dtypes.AttributeValue{
   384  			"LockID": &dtypes.AttributeValueMemberS{Value: c.lockPath() + stateIDSuffix},
   385  		},
   386  		TableName: aws.String(c.ddbTable),
   387  	}
   388  	if _, err := c.dynClient.DeleteItem(ctx, params); err != nil {
   389  		return err
   390  	}
   391  	return nil
   392  }
   393  
   394  func (c *RemoteClient) getLockInfo(ctx context.Context) (*statemgr.LockInfo, error) {
   395  	getParams := &dynamodb.GetItemInput{
   396  		Key: map[string]dtypes.AttributeValue{
   397  			"LockID": &dtypes.AttributeValueMemberS{Value: c.lockPath()},
   398  		},
   399  		ProjectionExpression: aws.String("LockID, Info"),
   400  		TableName:            aws.String(c.ddbTable),
   401  		ConsistentRead:       aws.Bool(true),
   402  	}
   403  
   404  	resp, err := c.dynClient.GetItem(ctx, getParams)
   405  	if err != nil {
   406  		return nil, err
   407  	}
   408  
   409  	var infoData string
   410  	if v, ok := resp.Item["Info"]; ok {
   411  		if v, ok := v.(*dtypes.AttributeValueMemberS); ok {
   412  			infoData = v.Value
   413  		}
   414  	}
   415  
   416  	lockInfo := &statemgr.LockInfo{}
   417  	err = json.Unmarshal([]byte(infoData), lockInfo)
   418  	if err != nil {
   419  		return nil, err
   420  	}
   421  
   422  	return lockInfo, nil
   423  }
   424  
   425  func (c *RemoteClient) Unlock(id string) error {
   426  	if c.ddbTable == "" {
   427  		return nil
   428  	}
   429  
   430  	lockErr := &statemgr.LockError{}
   431  	ctx := context.TODO()
   432  
   433  	// TODO: store the path and lock ID in separate fields, and have proper
   434  	// projection expression only delete the lock if both match, rather than
   435  	// checking the ID from the info field first.
   436  	lockInfo, err := c.getLockInfo(ctx)
   437  	if err != nil {
   438  		lockErr.Err = fmt.Errorf("failed to retrieve lock info: %w", err)
   439  		return lockErr
   440  	}
   441  	lockErr.Info = lockInfo
   442  
   443  	if lockInfo.ID != id {
   444  		lockErr.Err = fmt.Errorf("lock id %q does not match existing lock", id)
   445  		return lockErr
   446  	}
   447  
   448  	params := &dynamodb.DeleteItemInput{
   449  		Key: map[string]dtypes.AttributeValue{
   450  			"LockID": &dtypes.AttributeValueMemberS{Value: c.lockPath()},
   451  		},
   452  		TableName: aws.String(c.ddbTable),
   453  	}
   454  	_, err = c.dynClient.DeleteItem(ctx, params)
   455  
   456  	if err != nil {
   457  		lockErr.Err = err
   458  		return lockErr
   459  	}
   460  	return nil
   461  }
   462  
   463  func (c *RemoteClient) lockPath() string {
   464  	return fmt.Sprintf("%s/%s", c.bucketName, c.path)
   465  }
   466  
   467  func (c *RemoteClient) getSSECustomerKeyMD5() string {
   468  	b := md5.Sum(c.customerEncryptionKey)
   469  	return base64.StdEncoding.EncodeToString(b[:])
   470  }
   471  
// errBadChecksumFmt is the error format returned by Get when the fetched
// state still does not match the expected digest after the retry window. Its
// single %x verb receives the MD5 digest of the state that was actually
// fetched.
const errBadChecksumFmt = `state data in S3 does not have the expected content.

This may be caused by unusually long delays in S3 processing a previous state
update.  Please wait for a minute or two and try again. If this problem
persists, and neither S3 nor DynamoDB are experiencing an outage, you may need
to manually verify the remote state and update the Digest value stored in the
DynamoDB table to the following value: %x
`
   480  
// errS3NoSuchBucket is the error format used when S3 reports that the bucket
// does not exist. Its single %w verb wraps the underlying S3 error, so the
// format must be used with fmt.Errorf.
const errS3NoSuchBucket = `S3 bucket does not exist.

The referenced S3 bucket must have been previously created. If the S3 bucket
was created within the last minute, please wait for a minute or two and try
again.

Error: %w
`