github.com/treeverse/lakefs@v1.24.1-0.20240520134607-95648127bfb0/pkg/block/s3/client_cache.go (about)

     1  package s3
     2  
     3  import (
     4  	"context"
     5  	"sync"
     6  
     7  	"github.com/aws/aws-sdk-go-v2/aws"
     8  	"github.com/aws/aws-sdk-go-v2/feature/s3/manager"
     9  	"github.com/aws/aws-sdk-go-v2/service/s3"
    10  	"github.com/treeverse/lakefs/pkg/block/params"
    11  	"github.com/treeverse/lakefs/pkg/logging"
    12  	"github.com/treeverse/lakefs/pkg/stats"
    13  )
    14  
    15  type (
    16  	clientFactory  func(region string) *s3.Client
    17  	s3RegionGetter func(ctx context.Context, bucket string) (string, error)
    18  )
    19  
    20  type ClientCache struct {
    21  	mu             sync.Mutex
    22  	regionClient   map[string]*s3.Client
    23  	bucketRegion   map[string]string
    24  	awsConfig      aws.Config
    25  	defaultClient  *s3.Client
    26  	clientFactory  clientFactory
    27  	s3RegionGetter s3RegionGetter
    28  	collector      stats.Collector
    29  }
    30  
    31  func NewClientCache(awsConfig aws.Config, params params.S3) *ClientCache {
    32  	clientFactory := newClientFactory(awsConfig, WithClientParams(params))
    33  	defaultClient := clientFactory(awsConfig.Region)
    34  	clientCache := &ClientCache{
    35  		regionClient:  make(map[string]*s3.Client),
    36  		bucketRegion:  make(map[string]string),
    37  		awsConfig:     awsConfig,
    38  		defaultClient: defaultClient,
    39  		clientFactory: clientFactory,
    40  		collector:     &stats.NullCollector{},
    41  	}
    42  	clientCache.DiscoverBucketRegion(true)
    43  	return clientCache
    44  }
    45  
    46  // newClientFactory returns a function that creates a new S3 client with the given region.
    47  // accepts aws configuration and list of s3 options functions to apply with the s3 client.
    48  // the factory function is used to create a new client for a region when it is not cached.
    49  func newClientFactory(awsConfig aws.Config, s3OptFns ...func(options *s3.Options)) clientFactory {
    50  	return func(region string) *s3.Client {
    51  		return s3.NewFromConfig(awsConfig, func(options *s3.Options) {
    52  			for _, opts := range s3OptFns {
    53  				opts(options)
    54  			}
    55  			options.Region = region
    56  		})
    57  	}
    58  }
    59  
    60  func (c *ClientCache) SetClientFactory(clientFactory clientFactory) {
    61  	c.clientFactory = clientFactory
    62  }
    63  
    64  func (c *ClientCache) SetS3RegionGetter(s3RegionGetter s3RegionGetter) {
    65  	c.s3RegionGetter = s3RegionGetter
    66  }
    67  
    68  func (c *ClientCache) SetStatsCollector(statsCollector stats.Collector) {
    69  	c.collector = statsCollector
    70  }
    71  
    72  func (c *ClientCache) DiscoverBucketRegion(b bool) {
    73  	if b {
    74  		c.s3RegionGetter = c.getBucketRegionFromAWS
    75  	} else {
    76  		c.s3RegionGetter = c.getBucketRegionDefault
    77  	}
    78  }
    79  
    80  func (c *ClientCache) getBucketRegionFromAWS(ctx context.Context, bucket string) (string, error) {
    81  	return manager.GetBucketRegion(ctx, c.defaultClient, bucket)
    82  }
    83  
    84  func (c *ClientCache) getBucketRegionDefault(_ context.Context, _ string) (string, error) {
    85  	return c.awsConfig.Region, nil
    86  }
    87  
    88  func (c *ClientCache) Get(ctx context.Context, bucket string) *s3.Client {
    89  	client, region := c.cachedClientByBucket(bucket)
    90  	if client != nil {
    91  		return client
    92  	}
    93  
    94  	// lookup region if needed
    95  	if region == "" {
    96  		region = c.refreshBucketRegion(ctx, bucket)
    97  		if client, ok := c.cachedClientByRegion(region); ok {
    98  			return client
    99  		}
   100  	}
   101  
   102  	// create client and update cache
   103  	logging.FromContext(ctx).WithField("region", region).Debug("creating client for region")
   104  	client = c.clientFactory(region)
   105  
   106  	// re-check if a client was created by another goroutine
   107  	// keep using the existing client and discard the new one
   108  	c.mu.Lock()
   109  	existingClient, existingFound := c.regionClient[region]
   110  	if existingFound {
   111  		client = existingClient
   112  	} else {
   113  		c.regionClient[region] = client
   114  	}
   115  	c.mu.Unlock()
   116  
   117  	// report client creation, if needed
   118  	if !existingFound && c.collector != nil {
   119  		c.collector.CollectEvent(stats.Event{
   120  			Class: "s3_block_adapter",
   121  			Name:  "created_aws_client_" + region,
   122  		})
   123  	}
   124  	return client
   125  }
   126  
   127  func (c *ClientCache) cachedClientByBucket(bucket string) (*s3.Client, string) {
   128  	c.mu.Lock()
   129  	defer c.mu.Unlock()
   130  	if region, ok := c.bucketRegion[bucket]; ok {
   131  		return c.regionClient[region], region
   132  	}
   133  	return nil, ""
   134  }
   135  
   136  func (c *ClientCache) cachedClientByRegion(region string) (*s3.Client, bool) {
   137  	c.mu.Lock()
   138  	defer c.mu.Unlock()
   139  	client, ok := c.regionClient[region]
   140  	return client, ok
   141  }
   142  
   143  func (c *ClientCache) refreshBucketRegion(ctx context.Context, bucket string) string {
   144  	region, err := c.s3RegionGetter(ctx, bucket)
   145  	if err != nil {
   146  		// fallback to default region
   147  		region = c.awsConfig.Region
   148  		logging.FromContext(ctx).
   149  			WithError(err).
   150  			WithField("default_region", region).
   151  			Error("Failed to get region for bucket, falling back to default region")
   152  	}
   153  	// update bucket to region cache
   154  	c.mu.Lock()
   155  	c.bucketRegion[bucket] = region
   156  	c.mu.Unlock()
   157  	return region
   158  }
   159  
   160  func (c *ClientCache) GetDefault() *s3.Client {
   161  	return c.defaultClient
   162  }