github.com/Schaudge/grailbase@v0.0.0-20240223061707-44c758a471c0/file/s3file/list_bucket.go (about)

     1  package s3file
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"sort"
     7  	"sync"
     8  
     9  	"github.com/aws/aws-sdk-go/service/s3"
    10  	"github.com/aws/aws-sdk-go/service/s3/s3iface"
    11  	"github.com/Schaudge/grailbase/errors"
    12  	"github.com/Schaudge/grailbase/file"
    13  	"github.com/Schaudge/grailbase/log"
    14  	"github.com/Schaudge/grailbase/traverse"
    15  )
    16  
    17  type s3BucketLister struct {
    18  	ctx     context.Context
    19  	clients []s3iface.S3API
    20  	scheme  string
    21  
    22  	err     error
    23  	listed  bool
    24  	bucket  string
    25  	buckets []string
    26  }
    27  
    28  func (l *s3BucketLister) Scan() bool {
    29  	if !l.listed {
    30  		l.buckets, l.err = combineClientBuckets(l.ctx, l.clients)
    31  		l.listed = true
    32  	}
    33  	if l.err != nil || len(l.buckets) == 0 {
    34  		return false
    35  	}
    36  	l.bucket, l.buckets = l.buckets[0], l.buckets[1:]
    37  	return true
    38  }
    39  
    40  // combineClientBuckets returns the union of buckets from each client, since each may have
    41  // different permissions.
    42  func combineClientBuckets(ctx context.Context, clients []s3iface.S3API) ([]string, error) {
    43  	var (
    44  		uniqueBucketsMu sync.Mutex
    45  		uniqueBuckets   = map[string]struct{}{}
    46  	)
    47  	err := traverse.Parallel.Each(len(clients), func(clientIdx int) error {
    48  		buckets, err := listClientBuckets(ctx, clients[clientIdx])
    49  		if err != nil {
    50  			if errors.Is(errors.NotAllowed, err) {
    51  				log.Debug.Printf("s3file.listbuckets: ignoring: %v", err)
    52  				return nil
    53  			}
    54  			return err
    55  		}
    56  		uniqueBucketsMu.Lock()
    57  		defer uniqueBucketsMu.Unlock()
    58  		for _, bucket := range buckets {
    59  			uniqueBuckets[bucket] = struct{}{}
    60  		}
    61  		return nil
    62  	})
    63  	if err != nil {
    64  		return nil, err
    65  	}
    66  	buckets := make([]string, 0, len(uniqueBuckets))
    67  	for bucket := range uniqueBuckets {
    68  		buckets = append(buckets, bucket)
    69  	}
    70  	sort.Strings(buckets)
    71  	return buckets, nil
    72  }
    73  
    74  func listClientBuckets(ctx context.Context, client s3iface.S3API) ([]string, error) {
    75  	policy := newBackoffPolicy([]s3iface.S3API{client}, file.Opts{})
    76  	for {
    77  		var ids s3RequestIDs
    78  		res, err := policy.client().ListBucketsWithContext(ctx, &s3.ListBucketsInput{}, ids.captureOption())
    79  		if policy.shouldRetry(ctx, err, "listbuckets") {
    80  			continue
    81  		}
    82  		if err != nil {
    83  			return nil, annotate(err, ids, &policy, "s3file.listbuckets")
    84  		}
    85  		buckets := make([]string, len(res.Buckets))
    86  		for i, bucket := range res.Buckets {
    87  			buckets[i] = *bucket.Name
    88  		}
    89  		return buckets, nil
    90  	}
    91  }
    92  
    93  func (l *s3BucketLister) Path() string {
    94  	return fmt.Sprintf("%s://%s", l.scheme, l.bucket)
    95  }
    96  
    97  func (l *s3BucketLister) Info() file.Info { return nil }
    98  
    99  func (l *s3BucketLister) IsDir() bool {
   100  	return true
   101  }
   102  
   103  // Err returns an error, if any.
   104  func (l *s3BucketLister) Err() error {
   105  	return l.err
   106  }