github.com/treeverse/lakefs@v1.24.1-0.20240520134607-95648127bfb0/pkg/block/s3/adapter.go (about) 1 package s3 2 3 import ( 4 "context" 5 "crypto/tls" 6 "errors" 7 "fmt" 8 "io" 9 "net/http" 10 "net/url" 11 "strings" 12 "sync/atomic" 13 "time" 14 15 "github.com/aws/aws-sdk-go-v2/aws" 16 v4 "github.com/aws/aws-sdk-go-v2/aws/signer/v4" 17 "github.com/aws/aws-sdk-go-v2/config" 18 "github.com/aws/aws-sdk-go-v2/credentials" 19 "github.com/aws/aws-sdk-go-v2/credentials/stscreds" 20 "github.com/aws/aws-sdk-go-v2/feature/s3/manager" 21 "github.com/aws/aws-sdk-go-v2/service/s3" 22 "github.com/aws/aws-sdk-go-v2/service/s3/types" 23 "github.com/aws/smithy-go/middleware" 24 smithyhttp "github.com/aws/smithy-go/transport/http" 25 "github.com/treeverse/lakefs/pkg/block" 26 "github.com/treeverse/lakefs/pkg/block/params" 27 "github.com/treeverse/lakefs/pkg/logging" 28 "github.com/treeverse/lakefs/pkg/stats" 29 ) 30 31 var ( 32 ErrS3 = errors.New("s3 error") 33 ErrMissingETag = fmt.Errorf("%w: missing ETag", ErrS3) 34 ) 35 36 type Adapter struct { 37 clients *ClientCache 38 respServer atomic.Pointer[string] 39 ServerSideEncryption string 40 ServerSideEncryptionKmsKeyID string 41 preSignedExpiry time.Duration 42 sessionExpiryWindow time.Duration 43 disablePreSigned bool 44 disablePreSignedUI bool 45 disablePreSignedMultipart bool 46 } 47 48 func WithStatsCollector(s stats.Collector) func(a *Adapter) { 49 return func(a *Adapter) { 50 a.clients.SetStatsCollector(s) 51 } 52 } 53 54 func WithDiscoverBucketRegion(b bool) func(a *Adapter) { 55 return func(a *Adapter) { 56 a.clients.DiscoverBucketRegion(b) 57 } 58 } 59 60 func WithPreSignedExpiry(v time.Duration) func(a *Adapter) { 61 return func(a *Adapter) { 62 a.preSignedExpiry = v 63 } 64 } 65 66 func WithDisablePreSigned(b bool) func(a *Adapter) { 67 return func(a *Adapter) { 68 if b { 69 a.disablePreSigned = true 70 } 71 } 72 } 73 74 func WithDisablePreSignedUI(b bool) func(a *Adapter) { 75 return func(a *Adapter) { 76 if b { 77 a.disablePreSignedUI = true 78 } 79 } 80 } 81 82 func WithDisablePreSignedMultipart(b bool) func(a *Adapter) { 83 return func(a *Adapter) { 84 if b { 85 a.disablePreSignedMultipart = true 86 } 87 } 88 } 89 90 func WithServerSideEncryption(s string) func(a *Adapter) { 91 return func(a *Adapter) { 92 a.ServerSideEncryption = s 93 } 94 } 95 96 func WithServerSideEncryptionKmsKeyID(s string) func(a *Adapter) { 97 return func(a *Adapter) { 98 a.ServerSideEncryptionKmsKeyID = s 99 } 100 } 101 102 type AdapterOption func(a *Adapter) 103 104 func NewAdapter(ctx context.Context, params params.S3, opts ...AdapterOption) (*Adapter, error) { 105 cfg, err := LoadConfig(ctx, params) 106 if err != nil { 107 return nil, err 108 } 109 var sessionExpiryWindow time.Duration 110 if params.WebIdentity != nil { 111 sessionExpiryWindow = params.WebIdentity.SessionExpiryWindow 112 } 113 a := &Adapter{ 114 clients: NewClientCache(cfg, params), 115 preSignedExpiry: block.DefaultPreSignExpiryDuration, 116 sessionExpiryWindow: sessionExpiryWindow, 117 } 118 for _, opt := range opts { 119 opt(a) 120 } 121 return a, nil 122 } 123 124 func LoadConfig(ctx context.Context, params params.S3) (aws.Config, error) { 125 var opts []func(*config.LoadOptions) error 126 127 opts = append(opts, config.WithLogger(&logging.AWSAdapter{ 128 Logger: logging.ContextUnavailable().WithField("sdk", "aws"), 129 })) 130 var logMode aws.ClientLogMode 131 if params.ClientLogRetries { 132 logMode |= aws.LogRetries 133 } 134 if params.ClientLogRequest { 135 logMode |= aws.LogRequest 136 } 137 if logMode != 0 { 138 opts = append(opts, config.WithClientLogMode(logMode)) 139 } 140 if params.Region != "" { 141 opts = append(opts, config.WithRegion(params.Region)) 142 } 143 if params.Profile != "" { 144 opts = append(opts, config.WithSharedConfigProfile(params.Profile)) 145 } 146 if params.CredentialsFile != "" { 147 opts = append(opts, config.WithSharedCredentialsFiles([]string{params.CredentialsFile})) 148 } 149 if params.Credentials.AccessKeyID != "" { 150 opts = append(opts, config.WithCredentialsProvider( 151 credentials.NewStaticCredentialsProvider( 152 params.Credentials.AccessKeyID, 153 params.Credentials.SecretAccessKey, 154 params.Credentials.SessionToken, 155 ), 156 )) 157 } 158 if params.MaxRetries > 0 { 159 opts = append(opts, config.WithRetryMaxAttempts(params.MaxRetries)) 160 } 161 if params.SkipVerifyCertificateTestOnly { 162 tr := &http.Transport{ 163 TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, //nolint:gosec 164 } 165 opts = append(opts, config.WithHTTPClient(&http.Client{Transport: tr})) 166 } 167 if params.WebIdentity != nil { 168 wi := *params.WebIdentity // Copy WebIdentity: it will be used asynchronously. 169 if wi.SessionDuration > 0 { 170 opts = append(opts, config.WithWebIdentityRoleCredentialOptions( 171 func(options *stscreds.WebIdentityRoleOptions) { 172 options.Duration = wi.SessionDuration 173 }), 174 ) 175 } 176 if wi.SessionExpiryWindow > 0 { 177 opts = append(opts, config.WithCredentialsCacheOptions( 178 func(options *aws.CredentialsCacheOptions) { 179 options.ExpiryWindow = wi.SessionExpiryWindow 180 }), 181 ) 182 } 183 } 184 return config.LoadDefaultConfig(ctx, opts...) 185 } 186 187 func WithClientParams(params params.S3) func(options *s3.Options) { 188 return func(options *s3.Options) { 189 if params.Endpoint != "" { 190 options.BaseEndpoint = aws.String(params.Endpoint) 191 } 192 if params.ForcePathStyle { 193 options.UsePathStyle = true 194 } 195 } 196 } 197 198 func (a *Adapter) log(ctx context.Context) logging.Logger { 199 return logging.FromContext(ctx) 200 } 201 202 func (a *Adapter) Put(ctx context.Context, obj block.ObjectPointer, sizeBytes int64, reader io.Reader, opts block.PutOpts) error { 203 var err error 204 defer reportMetrics("Put", time.Now(), &sizeBytes, &err) 205 206 // for unknown size, we assume we like to stream content, will use s3manager to perform the request. 207 // we assume the caller may not have 1:1 request to s3 put object in this case as it may perform multipart upload 208 if sizeBytes == -1 { 209 return a.managerUpload(ctx, obj, reader, opts) 210 } 211 212 bucket, key, _, err := a.extractParamsFromObj(obj) 213 if err != nil { 214 return err 215 } 216 217 putObject := s3.PutObjectInput{ 218 Bucket: aws.String(bucket), 219 Key: aws.String(key), 220 Body: reader, 221 ContentLength: aws.Int64(sizeBytes), 222 } 223 if sizeBytes == 0 { 224 putObject.Body = http.NoBody 225 } 226 if opts.StorageClass != nil { 227 putObject.StorageClass = types.StorageClass(*opts.StorageClass) 228 } 229 if a.ServerSideEncryption != "" { 230 putObject.ServerSideEncryption = types.ServerSideEncryption(a.ServerSideEncryption) 231 } 232 if a.ServerSideEncryptionKmsKeyID != "" { 233 putObject.SSEKMSKeyId = aws.String(a.ServerSideEncryptionKmsKeyID) 234 } 235 236 client := a.clients.Get(ctx, bucket) 237 resp, err := client.PutObject(ctx, &putObject, 238 retryMaxAttemptsByReader(reader), 239 s3.WithAPIOptions(v4.SwapComputePayloadSHA256ForUnsignedPayloadMiddleware), 240 a.registerCaptureServerMiddleware(), 241 ) 242 if err != nil { 243 return err 244 } 245 etag := aws.ToString(resp.ETag) 246 if etag == "" { 247 return ErrMissingETag 248 } 249 return nil 250 } 251 252 // retryMaxAttemptsByReader return s3 options function 253 // setup RetryMaxAttempts - if the reader is not seekable, we can't retry the request 254 func retryMaxAttemptsByReader(reader io.Reader) func(*s3.Options) { 255 return func(o *s3.Options) { 256 if _, ok := reader.(io.Seeker); !ok { 257 o.RetryMaxAttempts = 1 258 } 259 } 260 } 261 262 // captureServerDeserializeMiddleware extracts the server name from the response and sets it on the block adapter 263 func (a *Adapter) captureServerDeserializeMiddleware(ctx context.Context, input middleware.DeserializeInput, handler middleware.DeserializeHandler) (middleware.DeserializeOutput, middleware.Metadata, error) { 264 output, m, err := handler.HandleDeserialize(ctx, input) 265 if err == nil { 266 if rawResponse, ok := output.RawResponse.(*smithyhttp.Response); ok { 267 s := rawResponse.Header.Get("Server") 268 if s != "" { 269 a.respServer.Store(&s) 270 } 271 } 272 } 273 return output, m, err 274 } 275 276 func (a *Adapter) UploadPart(ctx context.Context, obj block.ObjectPointer, sizeBytes int64, reader io.Reader, uploadID string, partNumber int) (*block.UploadPartResponse, error) { 277 var err error 278 defer reportMetrics("UploadPart", time.Now(), &sizeBytes, &err) 279 bucket, key, _, err := a.extractParamsFromObj(obj) 280 if err != nil { 281 return nil, err 282 } 283 284 uploadPartInput := &s3.UploadPartInput{ 285 Bucket: aws.String(bucket), 286 Key: aws.String(key), 287 PartNumber: aws.Int32(int32(partNumber)), 288 UploadId: aws.String(uploadID), 289 Body: reader, 290 ContentLength: aws.Int64(sizeBytes), 291 } 292 if a.ServerSideEncryption != "" { 293 uploadPartInput.SSECustomerAlgorithm = &a.ServerSideEncryption 294 } 295 if a.ServerSideEncryptionKmsKeyID != "" { 296 uploadPartInput.SSECustomerKey = &a.ServerSideEncryptionKmsKeyID 297 } 298 299 client := a.clients.Get(ctx, bucket) 300 resp, err := client.UploadPart(ctx, uploadPartInput, 301 retryMaxAttemptsByReader(reader), 302 s3.WithAPIOptions(v4.SwapComputePayloadSHA256ForUnsignedPayloadMiddleware), 303 a.registerCaptureServerMiddleware(), 304 ) 305 if err != nil { 306 return nil, err 307 } 308 etag := aws.ToString(resp.ETag) 309 if etag == "" { 310 return nil, ErrMissingETag 311 } 312 return &block.UploadPartResponse{ 313 ETag: strings.Trim(etag, `"`), 314 ServerSideHeader: extractSSHeaderUploadPart(resp), 315 }, nil 316 } 317 318 func isErrNotFound(err error) bool { 319 var ( 320 errNoSuchKey *types.NoSuchKey 321 errNotFound *types.NotFound 322 ) 323 return errors.As(err, &errNoSuchKey) || errors.As(err, &errNotFound) 324 } 325 326 func (a *Adapter) Get(ctx context.Context, obj block.ObjectPointer, _ int64) (io.ReadCloser, error) { 327 var err error 328 var sizeBytes int64 329 defer reportMetrics("Get", time.Now(), &sizeBytes, &err) 330 log := a.log(ctx).WithField("operation", "GetObject") 331 bucket, key, qualifiedKey, err := a.extractParamsFromObj(obj) 332 if err != nil { 333 return nil, err 334 } 335 336 getObjectInput := s3.GetObjectInput{ 337 Bucket: aws.String(bucket), 338 Key: aws.String(key), 339 } 340 client := a.clients.Get(ctx, bucket) 341 objectOutput, err := client.GetObject(ctx, &getObjectInput) 342 if isErrNotFound(err) { 343 return nil, block.ErrDataNotFound 344 } 345 if err != nil { 346 log.WithError(err).Errorf("failed to get S3 object bucket %s key %s", qualifiedKey.GetStorageNamespace(), qualifiedKey.GetKey()) 347 return nil, err 348 } 349 sizeBytes = aws.ToInt64(objectOutput.ContentLength) 350 return objectOutput.Body, nil 351 } 352 353 func (a *Adapter) GetWalker(uri *url.URL) (block.Walker, error) { 354 if err := block.ValidateStorageType(uri, block.StorageTypeS3); err != nil { 355 return nil, err 356 } 357 return NewS3Walker(a.clients.GetDefault()), nil 358 } 359 360 type CaptureExpiresPresigner struct { 361 Presigner s3.HTTPPresignerV4 362 CredentialsCanExpire bool 363 CredentialsExpireAt time.Time 364 } 365 366 func (c *CaptureExpiresPresigner) PresignHTTP(ctx context.Context, credentials aws.Credentials, r *http.Request, payloadHash string, service string, region string, signingTime time.Time, optFns ...func(*v4.SignerOptions)) (url string, signedHeader http.Header, err error) { 367 // capture credentials expiry 368 c.CredentialsCanExpire = credentials.CanExpire 369 c.CredentialsExpireAt = credentials.Expires 370 return c.Presigner.PresignHTTP(ctx, credentials, r, payloadHash, service, region, signingTime, optFns...) 371 } 372 373 func (a *Adapter) GetPreSignedURL(ctx context.Context, obj block.ObjectPointer, mode block.PreSignMode) (string, time.Time, error) { 374 if a.disablePreSigned { 375 return "", time.Time{}, block.ErrOperationNotSupported 376 } 377 378 expiry := time.Now().Add(a.preSignedExpiry) 379 380 log := a.log(ctx).WithFields(logging.Fields{ 381 "operation": "GetPreSignedURL", 382 "namespace": obj.StorageNamespace, 383 "identifier": obj.Identifier, 384 "ttl": time.Until(expiry), 385 }) 386 bucket, key, _, err := a.extractParamsFromObj(obj) 387 if err != nil { 388 log.WithError(err).Error("could not resolve namespace") 389 return "", time.Time{}, err 390 } 391 392 client := a.clients.Get(ctx, bucket) 393 presigner := s3.NewPresignClient(client, 394 func(options *s3.PresignOptions) { 395 options.Expires = a.preSignedExpiry 396 }) 397 398 captureExpiresPresigner := &CaptureExpiresPresigner{} 399 var req *v4.PresignedHTTPRequest 400 if mode == block.PreSignModeWrite { 401 putObjectInput := &s3.PutObjectInput{ 402 Bucket: aws.String(bucket), 403 Key: aws.String(key), 404 } 405 req, err = presigner.PresignPutObject(ctx, putObjectInput, func(o *s3.PresignOptions) { 406 captureExpiresPresigner.Presigner = o.Presigner 407 o.Presigner = captureExpiresPresigner 408 }) 409 } else { 410 getObjectInput := &s3.GetObjectInput{ 411 Bucket: aws.String(bucket), 412 Key: aws.String(key), 413 } 414 req, err = presigner.PresignGetObject(ctx, getObjectInput, func(o *s3.PresignOptions) { 415 captureExpiresPresigner.Presigner = o.Presigner 416 o.Presigner = captureExpiresPresigner 417 }) 418 } 419 if err != nil { 420 log.WithError(err).Error("could not pre-sign request") 421 return "", time.Time{}, err 422 } 423 424 // In case the credentials can expire, we need to use the earliest expiry time 425 // we assume that session expiry window is used and adjust the expiry time accordingly. 426 // AWS Go SDK v2 stores the time to renew credentials in `CredentialsExpireAt`. This is 427 // a.sessionExpiryWindow before actual credentials expiry. 428 if captureExpiresPresigner.CredentialsCanExpire && captureExpiresPresigner.CredentialsExpireAt.Before(expiry) { 429 expiry = captureExpiresPresigner.CredentialsExpireAt.Add(a.sessionExpiryWindow) 430 } 431 return req.URL, expiry, nil 432 } 433 434 func (a *Adapter) GetPresignUploadPartURL(ctx context.Context, obj block.ObjectPointer, uploadID string, partNumber int) (string, error) { 435 if a.disablePreSigned { 436 return "", block.ErrOperationNotSupported 437 } 438 439 log := a.log(ctx).WithFields(logging.Fields{ 440 "operation": "GetPresignUploadPartURL", 441 "namespace": obj.StorageNamespace, 442 "identifier": obj.Identifier, 443 }) 444 bucket, key, _, err := a.extractParamsFromObj(obj) 445 if err != nil { 446 log.WithError(err).Error("Could not resolve namespace") 447 return "", err 448 } 449 450 client := a.clients.Get(ctx, bucket) 451 presigner := s3.NewPresignClient(client, 452 func(options *s3.PresignOptions) { 453 options.Expires = a.preSignedExpiry 454 }, 455 ) 456 457 uploadInput := &s3.UploadPartInput{ 458 Bucket: aws.String(bucket), 459 Key: aws.String(key), 460 UploadId: aws.String(uploadID), 461 PartNumber: aws.Int32(int32(partNumber)), 462 } 463 uploadPart, err := presigner.PresignUploadPart(ctx, uploadInput) 464 if err != nil { 465 return "", err 466 } 467 return uploadPart.URL, nil 468 } 469 470 func (a *Adapter) Exists(ctx context.Context, obj block.ObjectPointer) (bool, error) { 471 var err error 472 defer reportMetrics("Exists", time.Now(), nil, &err) 473 log := a.log(ctx).WithField("operation", "HeadObject") 474 bucket, key, _, err := a.extractParamsFromObj(obj) 475 if err != nil { 476 return false, err 477 } 478 479 input := s3.HeadObjectInput{ 480 Bucket: aws.String(bucket), 481 Key: aws.String(key), 482 } 483 client := a.clients.Get(ctx, bucket) 484 _, err = client.HeadObject(ctx, &input) 485 if isErrNotFound(err) { 486 return false, nil 487 } 488 if err != nil { 489 log.WithError(err).Errorf("failed to stat S3 object") 490 return false, err 491 } 492 return true, nil 493 } 494 495 func (a *Adapter) GetRange(ctx context.Context, obj block.ObjectPointer, startPosition int64, endPosition int64) (io.ReadCloser, error) { 496 var err error 497 var sizeBytes int64 498 defer reportMetrics("GetRange", time.Now(), &sizeBytes, &err) 499 bucket, key, _, err := a.extractParamsFromObj(obj) 500 if err != nil { 501 return nil, err 502 } 503 log := a.log(ctx).WithField("operation", "GetObjectRange") 504 getObjectInput := s3.GetObjectInput{ 505 Bucket: aws.String(bucket), 506 Key: aws.String(key), 507 Range: aws.String(fmt.Sprintf("bytes=%d-%d", startPosition, endPosition)), 508 } 509 client := a.clients.Get(ctx, bucket) 510 objectOutput, err := client.GetObject(ctx, &getObjectInput) 511 if isErrNotFound(err) { 512 return nil, block.ErrDataNotFound 513 } 514 if err != nil { 515 log.WithError(err).WithFields(logging.Fields{ 516 "start_position": startPosition, 517 "end_position": endPosition, 518 }).Error("failed to get S3 object range") 519 return nil, err 520 } 521 sizeBytes = aws.ToInt64(objectOutput.ContentLength) 522 return objectOutput.Body, nil 523 } 524 525 func (a *Adapter) GetProperties(ctx context.Context, obj block.ObjectPointer) (block.Properties, error) { 526 var err error 527 defer reportMetrics("GetProperties", time.Now(), nil, &err) 528 bucket, key, _, err := a.extractParamsFromObj(obj) 529 if err != nil { 530 return block.Properties{}, err 531 } 532 533 headObjectParams := &s3.HeadObjectInput{ 534 Bucket: aws.String(bucket), 535 Key: aws.String(key), 536 } 537 client := a.clients.Get(ctx, bucket) 538 s3Props, err := client.HeadObject(ctx, headObjectParams) 539 if err != nil { 540 return block.Properties{}, err 541 } 542 return block.Properties{ 543 StorageClass: aws.String(string(s3Props.StorageClass)), 544 }, nil 545 } 546 547 func (a *Adapter) Remove(ctx context.Context, obj block.ObjectPointer) error { 548 var err error 549 defer reportMetrics("Remove", time.Now(), nil, &err) 550 bucket, key, _, err := a.extractParamsFromObj(obj) 551 if err != nil { 552 return err 553 } 554 555 deleteInput := &s3.DeleteObjectInput{ 556 Bucket: aws.String(bucket), 557 Key: aws.String(key), 558 } 559 client := a.clients.Get(ctx, bucket) 560 _, err = client.DeleteObject(ctx, deleteInput) 561 if err != nil { 562 a.log(ctx).WithError(err).Error("failed to delete S3 object") 563 return err 564 } 565 566 headInput := &s3.HeadObjectInput{ 567 Bucket: aws.String(bucket), 568 Key: aws.String(key), 569 } 570 const maxWaitDur = 100 * time.Second 571 waiter := s3.NewObjectNotExistsWaiter(client) 572 return waiter.Wait(ctx, headInput, maxWaitDur) 573 } 574 575 func (a *Adapter) copyPart(ctx context.Context, sourceObj, destinationObj block.ObjectPointer, uploadID string, partNumber int, byteRange *string) (*block.UploadPartResponse, error) { 576 srcKey, err := resolveNamespace(sourceObj) 577 if err != nil { 578 return nil, err 579 } 580 581 bucket, key, _, err := a.extractParamsFromObj(destinationObj) 582 if err != nil { 583 return nil, err 584 } 585 586 uploadPartCopyObject := s3.UploadPartCopyInput{ 587 Bucket: aws.String(bucket), 588 Key: aws.String(key), 589 PartNumber: aws.Int32(int32(partNumber)), 590 UploadId: aws.String(uploadID), 591 CopySource: aws.String(fmt.Sprintf("%s/%s", srcKey.GetStorageNamespace(), srcKey.GetKey())), 592 } 593 if byteRange != nil { 594 uploadPartCopyObject.CopySourceRange = byteRange 595 } 596 client := a.clients.Get(ctx, bucket) 597 resp, err := client.UploadPartCopy(ctx, &uploadPartCopyObject) 598 if err != nil { 599 return nil, err 600 } 601 if resp == nil || resp.CopyPartResult == nil || resp.CopyPartResult.ETag == nil { 602 return nil, ErrMissingETag 603 } 604 605 etag := strings.Trim(*resp.CopyPartResult.ETag, `"`) 606 return &block.UploadPartResponse{ 607 ETag: etag, 608 ServerSideHeader: extractSSHeaderUploadPartCopy(resp), 609 }, nil 610 } 611 612 func (a *Adapter) UploadCopyPart(ctx context.Context, sourceObj, destinationObj block.ObjectPointer, uploadID string, partNumber int) (*block.UploadPartResponse, error) { 613 var err error 614 defer reportMetrics("UploadCopyPart", time.Now(), nil, &err) 615 return a.copyPart(ctx, sourceObj, destinationObj, uploadID, partNumber, nil) 616 } 617 618 func (a *Adapter) UploadCopyPartRange(ctx context.Context, sourceObj, destinationObj block.ObjectPointer, uploadID string, partNumber int, startPosition, endPosition int64) (*block.UploadPartResponse, error) { 619 var err error 620 defer reportMetrics("UploadCopyPartRange", time.Now(), nil, &err) 621 return a.copyPart(ctx, 622 sourceObj, destinationObj, uploadID, partNumber, 623 aws.String(fmt.Sprintf("bytes=%d-%d", startPosition, endPosition))) 624 } 625 626 func (a *Adapter) Copy(ctx context.Context, sourceObj, destinationObj block.ObjectPointer) error { 627 var err error 628 defer reportMetrics("Copy", time.Now(), nil, &err) 629 qualifiedSourceKey, err := resolveNamespace(sourceObj) 630 if err != nil { 631 return err 632 } 633 634 destBucket, destKey, _, err := a.extractParamsFromObj(destinationObj) 635 if err != nil { 636 return err 637 } 638 639 copyObjectInput := &s3.CopyObjectInput{ 640 Bucket: aws.String(destBucket), 641 Key: aws.String(destKey), 642 CopySource: aws.String(qualifiedSourceKey.GetStorageNamespace() + "/" + qualifiedSourceKey.GetKey()), 643 } 644 if a.ServerSideEncryption != "" { 645 copyObjectInput.ServerSideEncryption = types.ServerSideEncryption(a.ServerSideEncryption) 646 } 647 if a.ServerSideEncryptionKmsKeyID != "" { 648 copyObjectInput.SSEKMSKeyId = aws.String(a.ServerSideEncryptionKmsKeyID) 649 } 650 _, err = a.clients.Get(ctx, destBucket).CopyObject(ctx, copyObjectInput) 651 if err != nil { 652 a.log(ctx).WithError(err).Error("failed to copy S3 object") 653 } 654 return err 655 } 656 657 func (a *Adapter) CreateMultiPartUpload(ctx context.Context, obj block.ObjectPointer, _ *http.Request, opts block.CreateMultiPartUploadOpts) (*block.CreateMultiPartUploadResponse, error) { 658 var err error 659 defer reportMetrics("CreateMultiPartUpload", time.Now(), nil, &err) 660 bucket, key, qualifiedKey, err := a.extractParamsFromObj(obj) 661 if err != nil { 662 return nil, err 663 } 664 665 input := &s3.CreateMultipartUploadInput{ 666 Bucket: aws.String(bucket), 667 Key: aws.String(key), 668 ContentType: aws.String(""), 669 Expires: aws.Time(time.Now().Add(a.preSignedExpiry)), 670 } 671 if opts.StorageClass != nil { 672 input.StorageClass = types.StorageClass(*opts.StorageClass) 673 } 674 if a.ServerSideEncryption != "" { 675 input.ServerSideEncryption = types.ServerSideEncryption(a.ServerSideEncryption) 676 } 677 if a.ServerSideEncryptionKmsKeyID != "" { 678 input.SSEKMSKeyId = &a.ServerSideEncryptionKmsKeyID 679 } 680 client := a.clients.Get(ctx, bucket) 681 resp, err := client.CreateMultipartUpload(ctx, input) 682 if err != nil { 683 return nil, err 684 } 685 uploadID := aws.ToString(resp.UploadId) 686 a.log(ctx).WithFields(logging.Fields{ 687 "upload_id": uploadID, 688 "qualified_ns": qualifiedKey.GetStorageNamespace(), 689 "qualified_key": qualifiedKey.GetKey(), 690 "key": obj.Identifier, 691 }).Debug("created multipart upload") 692 return &block.CreateMultiPartUploadResponse{ 693 UploadID: uploadID, 694 ServerSideHeader: extractSSHeaderCreateMultipartUpload(resp), 695 }, err 696 } 697 698 func (a *Adapter) AbortMultiPartUpload(ctx context.Context, obj block.ObjectPointer, uploadID string) error { 699 var err error 700 defer reportMetrics("AbortMultiPartUpload", time.Now(), nil, &err) 701 bucket, key, qualifiedKey, err := a.extractParamsFromObj(obj) 702 if err != nil { 703 return err 704 } 705 input := &s3.AbortMultipartUploadInput{ 706 Bucket: aws.String(bucket), 707 Key: aws.String(key), 708 UploadId: aws.String(uploadID), 709 } 710 711 client := a.clients.Get(ctx, bucket) 712 _, err = client.AbortMultipartUpload(ctx, input) 713 lg := a.log(ctx).WithFields(logging.Fields{ 714 "upload_id": uploadID, 715 "qualified_ns": qualifiedKey.GetStorageNamespace(), 716 "qualified_key": qualifiedKey.GetKey(), 717 "key": obj.Identifier, 718 }) 719 if err != nil { 720 lg.Error("Failed to abort multipart upload") 721 return err 722 } 723 lg.Debug("aborted multipart upload") 724 return nil 725 } 726 727 func convertFromBlockMultipartUploadCompletion(multipartList *block.MultipartUploadCompletion) *types.CompletedMultipartUpload { 728 parts := make([]types.CompletedPart, 0, len(multipartList.Part)) 729 for _, p := range multipartList.Part { 730 parts = append(parts, types.CompletedPart{ 731 ETag: aws.String(p.ETag), 732 PartNumber: aws.Int32(int32(p.PartNumber)), 733 }) 734 } 735 return &types.CompletedMultipartUpload{Parts: parts} 736 } 737 738 func (a *Adapter) CompleteMultiPartUpload(ctx context.Context, obj block.ObjectPointer, uploadID string, multipartList *block.MultipartUploadCompletion) (*block.CompleteMultiPartUploadResponse, error) { 739 var err error 740 defer reportMetrics("CompleteMultiPartUpload", time.Now(), nil, &err) 741 bucket, key, qualifiedKey, err := a.extractParamsFromObj(obj) 742 if err != nil { 743 return nil, err 744 } 745 input := &s3.CompleteMultipartUploadInput{ 746 Bucket: aws.String(bucket), 747 Key: aws.String(key), 748 UploadId: aws.String(uploadID), 749 MultipartUpload: convertFromBlockMultipartUploadCompletion(multipartList), 750 } 751 lg := a.log(ctx).WithFields(logging.Fields{ 752 "upload_id": uploadID, 753 "qualified_ns": qualifiedKey.GetStorageNamespace(), 754 "qualified_key": qualifiedKey.GetKey(), 755 "key": obj.Identifier, 756 }) 757 client := a.clients.Get(ctx, bucket) 758 resp, err := client.CompleteMultipartUpload(ctx, input) 759 if err != nil { 760 lg.WithError(err).Error("CompleteMultipartUpload failed") 761 return nil, err 762 } 763 lg.Debug("completed multipart upload") 764 headInput := &s3.HeadObjectInput{Bucket: &bucket, Key: &key} 765 headResp, err := client.HeadObject(ctx, headInput) 766 if err != nil { 767 return nil, err 768 } 769 770 etag := strings.Trim(aws.ToString(resp.ETag), `"`) 771 return &block.CompleteMultiPartUploadResponse{ 772 ETag: etag, 773 ContentLength: aws.ToInt64(headResp.ContentLength), 774 ServerSideHeader: extractSSHeaderCompleteMultipartUpload(resp), 775 }, nil 776 } 777 778 func (a *Adapter) ListParts(ctx context.Context, obj block.ObjectPointer, uploadID string, opts block.ListPartsOpts) (*block.ListPartsResponse, error) { 779 var err error 780 defer reportMetrics("ListParts", time.Now(), nil, &err) 781 bucket, key, qualifiedKey, err := a.extractParamsFromObj(obj) 782 if err != nil { 783 return nil, err 784 } 785 786 input := &s3.ListPartsInput{ 787 Bucket: aws.String(bucket), 788 Key: aws.String(key), 789 UploadId: aws.String(uploadID), 790 MaxParts: opts.MaxParts, 791 PartNumberMarker: opts.PartNumberMarker, 792 } 793 794 lg := a.log(ctx).WithFields(logging.Fields{ 795 "upload_id": uploadID, 796 "qualified_ns": qualifiedKey.GetStorageNamespace(), 797 "qualified_key": qualifiedKey.GetKey(), 798 "key": obj.Identifier, 799 "max_parts": opts.MaxParts, 800 "part_number_marker": opts.PartNumberMarker, 801 }) 802 client := a.clients.Get(ctx, bucket) 803 resp, err := client.ListParts(ctx, input) 804 if err != nil { 805 lg.WithError(err).Error("ListParts failed") 806 return nil, err 807 } 808 809 partsResp := block.ListPartsResponse{ 810 NextPartNumberMarker: resp.NextPartNumberMarker, 811 IsTruncated: aws.ToBool(resp.IsTruncated), 812 Parts: make([]block.MultipartPart, len(resp.Parts)), 813 } 814 for i, part := range resp.Parts { 815 partsResp.Parts[i] = block.MultipartPart{ 816 ETag: strings.Trim(aws.ToString(part.ETag), `"`), 817 PartNumber: int(aws.ToInt32(part.PartNumber)), 818 LastModified: aws.ToTime(part.LastModified), 819 Size: aws.ToInt64(part.Size), 820 } 821 } 822 823 lg.WithField("num_parts", len(resp.Parts)).Debug("list multipart upload parts") 824 825 return &partsResp, nil 826 } 827 828 func (a *Adapter) BlockstoreType() string { 829 return block.BlockstoreTypeS3 830 } 831 832 func (a *Adapter) GetStorageNamespaceInfo() block.StorageNamespaceInfo { 833 info := block.DefaultStorageNamespaceInfo(block.BlockstoreTypeS3) 834 if a.disablePreSigned { 835 info.PreSignSupport = false 836 } 837 if !(a.disablePreSignedUI || a.disablePreSigned) { 838 info.PreSignSupportUI = true 839 } 840 if !a.disablePreSignedMultipart && info.PreSignSupport { 841 info.PreSignSupportMultipart = true 842 } 843 return info 844 } 845 846 func resolveNamespace(obj block.ObjectPointer) (block.CommonQualifiedKey, error) { 847 qualifiedKey, err := block.DefaultResolveNamespace(obj.StorageNamespace, obj.Identifier, obj.IdentifierType) 848 if err != nil { 849 return qualifiedKey, err 850 } 851 if qualifiedKey.GetStorageType() != block.StorageTypeS3 { 852 return qualifiedKey, fmt.Errorf("expected storage type s3: %w", block.ErrInvalidAddress) 853 } 854 return qualifiedKey, nil 855 } 856 857 func (a *Adapter) ResolveNamespace(storageNamespace, key string, identifierType block.IdentifierType) (block.QualifiedKey, error) { 858 return block.DefaultResolveNamespace(storageNamespace, key, identifierType) 859 } 860 861 func (a *Adapter) RuntimeStats() map[string]string { 862 respServer := aws.ToString(a.respServer.Load()) 863 if respServer == "" { 864 return nil 865 } 866 return map[string]string{ 867 "resp_server": respServer, 868 } 869 } 870 871 func (a *Adapter) managerUpload(ctx context.Context, obj block.ObjectPointer, reader io.Reader, opts block.PutOpts) error { 872 bucket, key, _, err := a.extractParamsFromObj(obj) 873 if err != nil { 874 return err 875 } 876 877 client := a.clients.Get(ctx, bucket) 878 uploader := manager.NewUploader(client) 879 input := &s3.PutObjectInput{ 880 Bucket: aws.String(bucket), 881 Key: aws.String(key), 882 Body: reader, 883 } 884 if opts.StorageClass != nil { 885 input.StorageClass = types.StorageClass(*opts.StorageClass) 886 } 887 if a.ServerSideEncryption != "" { 888 input.ServerSideEncryption = types.ServerSideEncryption(a.ServerSideEncryption) 889 } 890 if a.ServerSideEncryptionKmsKeyID != "" { 891 input.SSEKMSKeyId = aws.String(a.ServerSideEncryptionKmsKeyID) 892 } 893 894 output, err := uploader.Upload(ctx, input) 895 if err != nil { 896 return err 897 } 898 if aws.ToString(output.ETag) == "" { 899 return ErrMissingETag 900 } 901 return nil 902 } 903 904 func (a *Adapter) extractParamsFromObj(obj block.ObjectPointer) (string, string, block.QualifiedKey, error) { 905 qk, err := a.ResolveNamespace(obj.StorageNamespace, obj.Identifier, obj.IdentifierType) 906 if err != nil { 907 return "", "", nil, err 908 } 909 bucket, key := ExtractParamsFromQK(qk) 910 return bucket, key, qk, nil 911 } 912 913 func (a *Adapter) registerCaptureServerMiddleware() func(*s3.Options) { 914 fn := middleware.DeserializeMiddlewareFunc("ResponseServerValue", a.captureServerDeserializeMiddleware) 915 return s3.WithAPIOptions(func(stack *middleware.Stack) error { 916 return stack.Deserialize.Add(fn, middleware.After) 917 }) 918 } 919 920 func ExtractParamsFromQK(qk block.QualifiedKey) (string, string) { 921 bucket, prefix, _ := strings.Cut(qk.GetStorageNamespace(), "/") 922 key := qk.GetKey() 923 if len(prefix) > 0 { // Avoid situations where prefix is empty or "/" 924 key = prefix + "/" + key 925 } 926 return bucket, key 927 }