github.com/terramate-io/tf@v0.0.0-20230830114523-fce866b4dfcd/backend/remote-state/s3/client.go (about) 1 // Copyright (c) HashiCorp, Inc. 2 // SPDX-License-Identifier: MPL-2.0 3 4 package s3 5 6 import ( 7 "bytes" 8 "crypto/md5" 9 "encoding/base64" 10 "encoding/hex" 11 "encoding/json" 12 "errors" 13 "fmt" 14 "io" 15 "log" 16 "time" 17 18 "github.com/aws/aws-sdk-go/aws" 19 "github.com/aws/aws-sdk-go/aws/awserr" 20 "github.com/aws/aws-sdk-go/service/dynamodb" 21 "github.com/aws/aws-sdk-go/service/s3" 22 multierror "github.com/hashicorp/go-multierror" 23 uuid "github.com/hashicorp/go-uuid" 24 "github.com/terramate-io/tf/states/remote" 25 "github.com/terramate-io/tf/states/statemgr" 26 ) 27 28 // Store the last saved serial in dynamo with this suffix for consistency checks. 29 const ( 30 s3EncryptionAlgorithm = "AES256" 31 stateIDSuffix = "-md5" 32 s3ErrCodeInternalError = "InternalError" 33 ) 34 35 type RemoteClient struct { 36 s3Client *s3.S3 37 dynClient *dynamodb.DynamoDB 38 bucketName string 39 path string 40 serverSideEncryption bool 41 customerEncryptionKey []byte 42 acl string 43 kmsKeyID string 44 ddbTable string 45 } 46 47 var ( 48 // The amount of time we will retry a state waiting for it to match the 49 // expected checksum. 50 consistencyRetryTimeout = 10 * time.Second 51 52 // delay when polling the state 53 consistencyRetryPollInterval = 2 * time.Second 54 ) 55 56 // test hook called when checksums don't match 57 var testChecksumHook func() 58 59 func (c *RemoteClient) Get() (payload *remote.Payload, err error) { 60 deadline := time.Now().Add(consistencyRetryTimeout) 61 62 // If we have a checksum, and the returned payload doesn't match, we retry 63 // up until deadline. 64 for { 65 payload, err = c.get() 66 if err != nil { 67 return nil, err 68 } 69 70 // If the remote state was manually removed the payload will be nil, 71 // but if there's still a digest entry for that state we will still try 72 // to compare the MD5 below. 73 var digest []byte 74 if payload != nil { 75 digest = payload.MD5 76 } 77 78 // verify that this state is what we expect 79 if expected, err := c.getMD5(); err != nil { 80 log.Printf("[WARN] failed to fetch state md5: %s", err) 81 } else if len(expected) > 0 && !bytes.Equal(expected, digest) { 82 log.Printf("[WARN] state md5 mismatch: expected '%x', got '%x'", expected, digest) 83 84 if testChecksumHook != nil { 85 testChecksumHook() 86 } 87 88 if time.Now().Before(deadline) { 89 time.Sleep(consistencyRetryPollInterval) 90 log.Println("[INFO] retrying S3 RemoteClient.Get...") 91 continue 92 } 93 94 return nil, fmt.Errorf(errBadChecksumFmt, digest) 95 } 96 97 break 98 } 99 100 return payload, err 101 } 102 103 func (c *RemoteClient) get() (*remote.Payload, error) { 104 var output *s3.GetObjectOutput 105 var err error 106 107 input := &s3.GetObjectInput{ 108 Bucket: &c.bucketName, 109 Key: &c.path, 110 } 111 112 if c.serverSideEncryption && c.customerEncryptionKey != nil { 113 input.SetSSECustomerKey(string(c.customerEncryptionKey)) 114 input.SetSSECustomerAlgorithm(s3EncryptionAlgorithm) 115 input.SetSSECustomerKeyMD5(c.getSSECustomerKeyMD5()) 116 } 117 118 output, err = c.s3Client.GetObject(input) 119 120 if err != nil { 121 if awserr, ok := err.(awserr.Error); ok { 122 switch awserr.Code() { 123 case s3.ErrCodeNoSuchBucket: 124 return nil, fmt.Errorf(errS3NoSuchBucket, err) 125 case s3.ErrCodeNoSuchKey: 126 return nil, nil 127 } 128 } 129 return nil, err 130 } 131 132 defer output.Body.Close() 133 134 buf := bytes.NewBuffer(nil) 135 if _, err := io.Copy(buf, output.Body); err != nil { 136 return nil, fmt.Errorf("Failed to read remote state: %s", err) 137 } 138 139 sum := md5.Sum(buf.Bytes()) 140 payload := &remote.Payload{ 141 Data: buf.Bytes(), 142 MD5: sum[:], 143 } 144 145 // If there was no data, then return nil 146 if len(payload.Data) == 0 { 147 return nil, nil 148 } 149 150 return payload, nil 151 } 152 153 func (c *RemoteClient) Put(data []byte) error { 154 contentType := "application/json" 155 contentLength := int64(len(data)) 156 157 i := &s3.PutObjectInput{ 158 ContentType: &contentType, 159 ContentLength: &contentLength, 160 Body: bytes.NewReader(data), 161 Bucket: &c.bucketName, 162 Key: &c.path, 163 } 164 165 if c.serverSideEncryption { 166 if c.kmsKeyID != "" { 167 i.SSEKMSKeyId = &c.kmsKeyID 168 i.ServerSideEncryption = aws.String("aws:kms") 169 } else if c.customerEncryptionKey != nil { 170 i.SetSSECustomerKey(string(c.customerEncryptionKey)) 171 i.SetSSECustomerAlgorithm(s3EncryptionAlgorithm) 172 i.SetSSECustomerKeyMD5(c.getSSECustomerKeyMD5()) 173 } else { 174 i.ServerSideEncryption = aws.String(s3EncryptionAlgorithm) 175 } 176 } 177 178 if c.acl != "" { 179 i.ACL = aws.String(c.acl) 180 } 181 182 log.Printf("[DEBUG] Uploading remote state to S3: %#v", i) 183 184 _, err := c.s3Client.PutObject(i) 185 if err != nil { 186 return fmt.Errorf("failed to upload state: %s", err) 187 } 188 189 sum := md5.Sum(data) 190 if err := c.putMD5(sum[:]); err != nil { 191 // if this errors out, we unfortunately have to error out altogether, 192 // since the next Get will inevitably fail. 193 return fmt.Errorf("failed to store state MD5: %s", err) 194 195 } 196 197 return nil 198 } 199 200 func (c *RemoteClient) Delete() error { 201 _, err := c.s3Client.DeleteObject(&s3.DeleteObjectInput{ 202 Bucket: &c.bucketName, 203 Key: &c.path, 204 }) 205 206 if err != nil { 207 return err 208 } 209 210 if err := c.deleteMD5(); err != nil { 211 log.Printf("error deleting state md5: %s", err) 212 } 213 214 return nil 215 } 216 217 func (c *RemoteClient) Lock(info *statemgr.LockInfo) (string, error) { 218 if c.ddbTable == "" { 219 return "", nil 220 } 221 222 info.Path = c.lockPath() 223 224 if info.ID == "" { 225 lockID, err := uuid.GenerateUUID() 226 if err != nil { 227 return "", err 228 } 229 230 info.ID = lockID 231 } 232 233 putParams := &dynamodb.PutItemInput{ 234 Item: map[string]*dynamodb.AttributeValue{ 235 "LockID": {S: aws.String(c.lockPath())}, 236 "Info": {S: aws.String(string(info.Marshal()))}, 237 }, 238 TableName: aws.String(c.ddbTable), 239 ConditionExpression: aws.String("attribute_not_exists(LockID)"), 240 } 241 _, err := c.dynClient.PutItem(putParams) 242 243 if err != nil { 244 lockInfo, infoErr := c.getLockInfo() 245 if infoErr != nil { 246 err = multierror.Append(err, infoErr) 247 } 248 249 lockErr := &statemgr.LockError{ 250 Err: err, 251 Info: lockInfo, 252 } 253 return "", lockErr 254 } 255 256 return info.ID, nil 257 } 258 259 func (c *RemoteClient) getMD5() ([]byte, error) { 260 if c.ddbTable == "" { 261 return nil, nil 262 } 263 264 getParams := &dynamodb.GetItemInput{ 265 Key: map[string]*dynamodb.AttributeValue{ 266 "LockID": {S: aws.String(c.lockPath() + stateIDSuffix)}, 267 }, 268 ProjectionExpression: aws.String("LockID, Digest"), 269 TableName: aws.String(c.ddbTable), 270 ConsistentRead: aws.Bool(true), 271 } 272 273 resp, err := c.dynClient.GetItem(getParams) 274 if err != nil { 275 return nil, err 276 } 277 278 var val string 279 if v, ok := resp.Item["Digest"]; ok && v.S != nil { 280 val = *v.S 281 } 282 283 sum, err := hex.DecodeString(val) 284 if err != nil || len(sum) != md5.Size { 285 return nil, errors.New("invalid md5") 286 } 287 288 return sum, nil 289 } 290 291 // store the hash of the state so that clients can check for stale state files. 292 func (c *RemoteClient) putMD5(sum []byte) error { 293 if c.ddbTable == "" { 294 return nil 295 } 296 297 if len(sum) != md5.Size { 298 return errors.New("invalid payload md5") 299 } 300 301 putParams := &dynamodb.PutItemInput{ 302 Item: map[string]*dynamodb.AttributeValue{ 303 "LockID": {S: aws.String(c.lockPath() + stateIDSuffix)}, 304 "Digest": {S: aws.String(hex.EncodeToString(sum))}, 305 }, 306 TableName: aws.String(c.ddbTable), 307 } 308 _, err := c.dynClient.PutItem(putParams) 309 if err != nil { 310 log.Printf("[WARN] failed to record state serial in dynamodb: %s", err) 311 } 312 313 return nil 314 } 315 316 // remove the hash value for a deleted state 317 func (c *RemoteClient) deleteMD5() error { 318 if c.ddbTable == "" { 319 return nil 320 } 321 322 params := &dynamodb.DeleteItemInput{ 323 Key: map[string]*dynamodb.AttributeValue{ 324 "LockID": {S: aws.String(c.lockPath() + stateIDSuffix)}, 325 }, 326 TableName: aws.String(c.ddbTable), 327 } 328 if _, err := c.dynClient.DeleteItem(params); err != nil { 329 return err 330 } 331 return nil 332 } 333 334 func (c *RemoteClient) getLockInfo() (*statemgr.LockInfo, error) { 335 getParams := &dynamodb.GetItemInput{ 336 Key: map[string]*dynamodb.AttributeValue{ 337 "LockID": {S: aws.String(c.lockPath())}, 338 }, 339 ProjectionExpression: aws.String("LockID, Info"), 340 TableName: aws.String(c.ddbTable), 341 ConsistentRead: aws.Bool(true), 342 } 343 344 resp, err := c.dynClient.GetItem(getParams) 345 if err != nil { 346 return nil, err 347 } 348 349 var infoData string 350 if v, ok := resp.Item["Info"]; ok && v.S != nil { 351 infoData = *v.S 352 } 353 354 lockInfo := &statemgr.LockInfo{} 355 err = json.Unmarshal([]byte(infoData), lockInfo) 356 if err != nil { 357 return nil, err 358 } 359 360 return lockInfo, nil 361 } 362 363 func (c *RemoteClient) Unlock(id string) error { 364 if c.ddbTable == "" { 365 return nil 366 } 367 368 lockErr := &statemgr.LockError{} 369 370 // TODO: store the path and lock ID in separate fields, and have proper 371 // projection expression only delete the lock if both match, rather than 372 // checking the ID from the info field first. 373 lockInfo, err := c.getLockInfo() 374 if err != nil { 375 lockErr.Err = fmt.Errorf("failed to retrieve lock info: %s", err) 376 return lockErr 377 } 378 lockErr.Info = lockInfo 379 380 if lockInfo.ID != id { 381 lockErr.Err = fmt.Errorf("lock id %q does not match existing lock", id) 382 return lockErr 383 } 384 385 params := &dynamodb.DeleteItemInput{ 386 Key: map[string]*dynamodb.AttributeValue{ 387 "LockID": {S: aws.String(c.lockPath())}, 388 }, 389 TableName: aws.String(c.ddbTable), 390 } 391 _, err = c.dynClient.DeleteItem(params) 392 393 if err != nil { 394 lockErr.Err = err 395 return lockErr 396 } 397 return nil 398 } 399 400 func (c *RemoteClient) lockPath() string { 401 return fmt.Sprintf("%s/%s", c.bucketName, c.path) 402 } 403 404 func (c *RemoteClient) getSSECustomerKeyMD5() string { 405 b := md5.Sum(c.customerEncryptionKey) 406 return base64.StdEncoding.EncodeToString(b[:]) 407 } 408 409 const errBadChecksumFmt = `state data in S3 does not have the expected content. 410 411 This may be caused by unusually long delays in S3 processing a previous state 412 update. Please wait for a minute or two and try again. If this problem 413 persists, and neither S3 nor DynamoDB are experiencing an outage, you may need 414 to manually verify the remote state and update the Digest value stored in the 415 DynamoDB table to the following value: %x 416 ` 417 418 const errS3NoSuchBucket = `S3 bucket does not exist. 419 420 The referenced S3 bucket must have been previously created. If the S3 bucket 421 was created within the last minute, please wait for a minute or two and try 422 again. 423 424 Error: %s 425 `