github.com/grailbio/base@v0.0.11/file/s3file/session_provider.go (about)

     1  // Copyright 2018 GRAIL, Inc. All rights reserved.
     2  // Use of this source code is governed by the Apache-2.0
     3  // license that can be found in the LICENSE file.
     4  
     5  package s3file
     6  
     7  import (
     8  	"context"
     9  	"fmt"
    10  	"runtime"
    11  	"sync"
    12  	"sync/atomic"
    13  	"time"
    14  
    15  	"github.com/aws/aws-sdk-go/aws"
    16  	"github.com/aws/aws-sdk-go/aws/session"
    17  	"github.com/aws/aws-sdk-go/service/s3"
    18  	"github.com/aws/aws-sdk-go/service/s3/s3iface"
    19  	"github.com/grailbio/base/errors"
    20  )
    21  
    22  const (
    23  	defaultRegion                        = "us-west-2"
    24  	clientCacheGarbageCollectionInterval = 10 * time.Minute
    25  )
    26  
    27  type (
    28  	// SessionProvider provides Sessions for making AWS API calls. Get() is called whenever s3file
    29  	// needs to access a file. The provider should cache and reuse the sessions, if needed.
    30  	// The implementation must be thread safe.
    31  	SessionProvider interface {
    32  		// Get returns AWS sessions that can be used to perform in.S3IAMAction on
    33  		// s3://{in.bucket}/{in.key}.
    34  		//
    35  		// s3file maintains an internal cache keyed by *session.Session that is only pruned
    36  		// occasionally. Get() is called for every S3 operation so it should be very fast. Caching
    37  		// (that is, reusing *session.Session whenever possible) is strongly encouraged.
    38  		//
    39  		// Get() must return >= 1 session, or error. If > 1, the S3 operation will be tried
    40  		// on each session in unspecified order until it succeeds.
    41  		//
    42  		// Note: Some implementations will not need SessionProviderInput and can just ignore it.
    43  		//
    44  		// TODO: Consider passing chan<- *session.Session (implementer sends and then closes)
    45  		// so s3file can try credentials as soon as they're available.
    46  		Get(_ context.Context, in SessionProviderInput) ([]*session.Session, error)
    47  	}
    48  	SessionProviderInput struct {
    49  		// S3IAMAction is an action name from this list:
    50  		// https://docs.aws.amazon.com/service-authorization/latest/reference/list_amazons3.html
    51  		//
    52  		// Note: There is no `s3:` prefix.
    53  		//
    54  		// Note: This is different from the notion of "action" in the S3 API documentation:
    55  		// https://docs.aws.amazon.com/AmazonS3/latest/API/API_Operations.html
    56  		// Some names, like GetObject, appear in both; others, like HeadObject, do not.
    57  		S3IAMAction string
    58  		// Bucket and Key describe the API operation to be performed, if applicable.
    59  		Bucket, Key string
    60  	}
    61  
    62  	constSessionProvider struct {
    63  		session *session.Session
    64  		err     error
    65  	}
    66  )
    67  
    68  // NewDefaultProvider returns a SessionProvider that calls session.NewSession(configs...) once.
    69  func NewDefaultProvider(configs ...*aws.Config) SessionProvider {
    70  	session, err := session.NewSession(configs...)
    71  	return constSessionProvider{session, err}
    72  }
    73  
    74  func (p constSessionProvider) Get(context.Context, SessionProviderInput) ([]*session.Session, error) {
    75  	if p.err != nil {
    76  		return nil, p.err
    77  	}
    78  	return []*session.Session{p.session}, nil
    79  }
    80  
    81  type (
    82  	clientsForActionFunc func(ctx context.Context, s3IAMAction, bucket, key string) ([]s3iface.S3API, error)
    83  	// clientCache caches clients for all regions, based on the user's SessionProvider.
    84  	clientCache struct {
    85  		provider SessionProvider
    86  		// clients maps clientCacheKey -> *clientCacheValue.
    87  		// TODO: Implement some kind of garbage collection and relax the documented constraint
    88  		// that sessions are never released.
    89  		clients *sync.Map
    90  	}
    91  	clientCacheKey struct {
    92  		region string
    93  		// userSession is the session that the user's SessionProvider returned.
    94  		// It may be configured for a different region, so we don't use it directly.
    95  		userSession *session.Session
    96  	}
    97  	clientCacheValue struct {
    98  		client *s3.S3
    99  		// usedSinceLastGC is 0 or 1. It's set when this client is used, and acted on by the
   100  		// GC goroutine.
   101  		// TODO: Use atomic.Bool in go1.19.
   102  		usedSinceLastGC int32
   103  	}
   104  )
   105  
   106  func newClientCache(provider SessionProvider) *clientCache {
   107  	// According to time.Tick documentation, ticker.Stop must be called to avoid leaking ticker
   108  	// memory. However, *clientCache is never explicitly "shut down", so we don't have a good way
   109  	// to stop the GC loop. Instead, we use a finalizer on *clientCache, and ensure the GC loop
   110  	// itself doesn't keep *clientCache alive.
   111  	var (
   112  		clients         sync.Map
   113  		gcCtx, gcCancel = context.WithCancel(context.Background())
   114  	)
   115  	go func() {
   116  		ticker := time.NewTicker(clientCacheGarbageCollectionInterval)
   117  		defer ticker.Stop()
   118  		for {
   119  			select {
   120  			case <-gcCtx.Done():
   121  				return
   122  			case <-ticker.C:
   123  			}
   124  			clients.Range(func(keyAny, valueAny any) bool {
   125  				key := keyAny.(clientCacheKey)
   126  				value := valueAny.(*clientCacheValue)
   127  				if atomic.SwapInt32(&value.usedSinceLastGC, 0) == 0 {
   128  					// Note: Concurrent goroutines could mark this client as used between our query
   129  					// and delete. That's fine; we'll just construct a new client next time.
   130  					clients.Delete(key)
   131  				}
   132  				return true
   133  			})
   134  		}
   135  	}()
   136  	// Note: Declare *clientCache after the GC loop to help ensure the latter doesn't keep a
   137  	// reference to the former.
   138  	cc := clientCache{provider, &clients}
   139  	runtime.SetFinalizer(&cc, func(any) { gcCancel() })
   140  	return &cc
   141  }
   142  
   143  func (c *clientCache) forAction(ctx context.Context, s3IAMAction, bucket, key string) ([]s3iface.S3API, error) {
   144  	// TODO: Consider using some better default, like current region if we're in EC2.
   145  	region := defaultRegion
   146  	if bucket != "" { // bucket is empty when listing buckets, for example.
   147  		var err error
   148  		region, err = FindBucketRegion(ctx, bucket)
   149  		if err != nil {
   150  			return nil, errors.E(err, fmt.Sprintf("locating region for bucket %s", bucket))
   151  		}
   152  	}
   153  	sessions, err := c.provider.Get(ctx, SessionProviderInput{S3IAMAction: s3IAMAction, Bucket: bucket, Key: key})
   154  	if err != nil {
   155  		return nil, errors.E(err, fmt.Sprintf("getting sessions from provider %T", c.provider))
   156  	}
   157  	clients := make([]s3iface.S3API, len(sessions))
   158  	for i, session := range sessions {
   159  		key := clientCacheKey{region, session}
   160  		obj, ok := c.clients.Load(key)
   161  		if !ok {
   162  			obj, _ = c.clients.LoadOrStore(key, &clientCacheValue{
   163  				client:          s3.New(session, &aws.Config{Region: &region}),
   164  				usedSinceLastGC: 1,
   165  			})
   166  		}
   167  		value := obj.(*clientCacheValue)
   168  		clients[i] = value.client
   169  		atomic.StoreInt32(&value.usedSinceLastGC, 1)
   170  	}
   171  	return clients, nil
   172  }