github.com/Schaudge/grailbase@v0.0.0-20240223061707-44c758a471c0/file/s3file/list_bucket.go (about) 1 package s3file 2 3 import ( 4 "context" 5 "fmt" 6 "sort" 7 "sync" 8 9 "github.com/aws/aws-sdk-go/service/s3" 10 "github.com/aws/aws-sdk-go/service/s3/s3iface" 11 "github.com/Schaudge/grailbase/errors" 12 "github.com/Schaudge/grailbase/file" 13 "github.com/Schaudge/grailbase/log" 14 "github.com/Schaudge/grailbase/traverse" 15 ) 16 17 type s3BucketLister struct { 18 ctx context.Context 19 clients []s3iface.S3API 20 scheme string 21 22 err error 23 listed bool 24 bucket string 25 buckets []string 26 } 27 28 func (l *s3BucketLister) Scan() bool { 29 if !l.listed { 30 l.buckets, l.err = combineClientBuckets(l.ctx, l.clients) 31 l.listed = true 32 } 33 if l.err != nil || len(l.buckets) == 0 { 34 return false 35 } 36 l.bucket, l.buckets = l.buckets[0], l.buckets[1:] 37 return true 38 } 39 40 // combineClientBuckets returns the union of buckets from each client, since each may have 41 // different permissions. 42 func combineClientBuckets(ctx context.Context, clients []s3iface.S3API) ([]string, error) { 43 var ( 44 uniqueBucketsMu sync.Mutex 45 uniqueBuckets = map[string]struct{}{} 46 ) 47 err := traverse.Parallel.Each(len(clients), func(clientIdx int) error { 48 buckets, err := listClientBuckets(ctx, clients[clientIdx]) 49 if err != nil { 50 if errors.Is(errors.NotAllowed, err) { 51 log.Debug.Printf("s3file.listbuckets: ignoring: %v", err) 52 return nil 53 } 54 return err 55 } 56 uniqueBucketsMu.Lock() 57 defer uniqueBucketsMu.Unlock() 58 for _, bucket := range buckets { 59 uniqueBuckets[bucket] = struct{}{} 60 } 61 return nil 62 }) 63 if err != nil { 64 return nil, err 65 } 66 buckets := make([]string, 0, len(uniqueBuckets)) 67 for bucket := range uniqueBuckets { 68 buckets = append(buckets, bucket) 69 } 70 sort.Strings(buckets) 71 return buckets, nil 72 } 73 74 func listClientBuckets(ctx context.Context, client s3iface.S3API) ([]string, error) { 75 policy := newBackoffPolicy([]s3iface.S3API{client}, file.Opts{}) 76 for { 77 var ids s3RequestIDs 78 res, err := policy.client().ListBucketsWithContext(ctx, &s3.ListBucketsInput{}, ids.captureOption()) 79 if policy.shouldRetry(ctx, err, "listbuckets") { 80 continue 81 } 82 if err != nil { 83 return nil, annotate(err, ids, &policy, "s3file.listbuckets") 84 } 85 buckets := make([]string, len(res.Buckets)) 86 for i, bucket := range res.Buckets { 87 buckets[i] = *bucket.Name 88 } 89 return buckets, nil 90 } 91 } 92 93 func (l *s3BucketLister) Path() string { 94 return fmt.Sprintf("%s://%s", l.scheme, l.bucket) 95 } 96 97 func (l *s3BucketLister) Info() file.Info { return nil } 98 99 func (l *s3BucketLister) IsDir() bool { 100 return true 101 } 102 103 // Err returns an error, if any. 104 func (l *s3BucketLister) Err() error { 105 return l.err 106 }