github.com/treeverse/lakefs@v1.24.1-0.20240520134607-95648127bfb0/pkg/block/azure/client_cache.go (about)

     1  package azure
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"time"
     7  
     8  	"github.com/Azure/azure-sdk-for-go/sdk/azcore"
     9  	"github.com/Azure/azure-sdk-for-go/sdk/azcore/policy"
    10  	"github.com/Azure/azure-sdk-for-go/sdk/azcore/to"
    11  	"github.com/Azure/azure-sdk-for-go/sdk/azidentity"
    12  	"github.com/Azure/azure-sdk-for-go/sdk/storage/azblob/container"
    13  	"github.com/Azure/azure-sdk-for-go/sdk/storage/azblob/sas"
    14  	"github.com/Azure/azure-sdk-for-go/sdk/storage/azblob/service"
    15  	lru "github.com/hnlq715/golang-lru"
    16  	"github.com/puzpuzpuz/xsync"
    17  	"github.com/treeverse/lakefs/pkg/block/params"
    18  	"golang.org/x/exp/slices"
    19  )
    20  
    21  const UDCCacheExpiry = time.Hour
    22  const UDCCacheWorkaroundDivider = 2
    23  
    24  type ClientCache struct {
    25  	serviceToClient   *xsync.MapOf[string, *service.Client]
    26  	containerToClient *xsync.MapOf[string, *container.Client]
    27  	// udcCache - User Delegation Credential cache used to reduce POST requests while creating pre-signed URLs
    28  	udcCache *lru.ARCCache
    29  	params   params.Azure
    30  }
    31  
    32  func NewCache(p params.Azure) (*ClientCache, error) {
    33  	l, err := lru.NewARCWithExpire(udcCacheSize, UDCCacheExpiry/UDCCacheWorkaroundDivider)
    34  	// TODO(Guys): dividing the udc cache expiry by 2 is a workaround for the fact that this package does not handle expiry correctly, we can remove this once we use https://github.com/hashicorp/golang-lru expirables
    35  	if err != nil {
    36  		return nil, err
    37  	}
    38  
    39  	return &ClientCache{
    40  		serviceToClient:   xsync.NewMapOf[*service.Client](),
    41  		containerToClient: xsync.NewMapOf[*container.Client](),
    42  		udcCache:          l,
    43  		params:            p,
    44  	}, nil
    45  }
    46  
    47  func mapKey(storageAccount, containerName string) string {
    48  	return fmt.Sprintf("%s#%s", storageAccount, containerName)
    49  }
    50  
    51  func (c *ClientCache) NewContainerClient(storageAccount, containerName string) (*container.Client, error) {
    52  	key := mapKey(storageAccount, containerName)
    53  
    54  	var err error
    55  	cl, _ := c.containerToClient.LoadOrCompute(key, func() *container.Client {
    56  		var svc *service.Client
    57  		svc, err = c.NewServiceClient(storageAccount)
    58  		if err != nil {
    59  			return nil
    60  		}
    61  		return svc.NewContainerClient(containerName)
    62  	})
    63  	if err != nil {
    64  		return nil, err
    65  	}
    66  
    67  	return cl, nil
    68  }
    69  
    70  func (c *ClientCache) NewServiceClient(storageAccount string) (*service.Client, error) {
    71  	p := c.params
    72  	// Use StorageAccessKey to initialize storage account client only if it was provided for this given storage account
    73  	// Otherwise fall back to the default credentials
    74  	if p.StorageAccount != storageAccount {
    75  		p.StorageAccount = storageAccount
    76  		p.StorageAccessKey = ""
    77  	}
    78  
    79  	var err error
    80  	cl, _ := c.serviceToClient.LoadOrCompute(storageAccount, func() *service.Client {
    81  		var svc *service.Client
    82  		svc, err = BuildAzureServiceClient(p)
    83  		if err != nil {
    84  			return nil
    85  		}
    86  		return svc
    87  	})
    88  	if err != nil {
    89  		return nil, err
    90  	}
    91  
    92  	return cl, nil
    93  }
    94  
    95  func (c *ClientCache) NewUDC(ctx context.Context, storageAccount string, expiry *time.Time) (*service.UserDelegationCredential, error) {
    96  	var udc *service.UserDelegationCredential
    97  	// Check udcCache
    98  	res, ok := c.udcCache.Get(storageAccount)
    99  	if !ok {
   100  		baseTime := time.Now().UTC().Add(-10 * time.Second)
   101  		// UDC expiry time of PreSignedExpiry + hour
   102  		udcExpiry := expiry.Add(UDCCacheExpiry)
   103  		info := service.KeyInfo{
   104  			Start:  to.Ptr(baseTime.UTC().Format(sas.TimeFormat)),
   105  			Expiry: to.Ptr(udcExpiry.Format(sas.TimeFormat)),
   106  		}
   107  		svc, err := c.NewServiceClient(storageAccount)
   108  		if err != nil {
   109  			return nil, err
   110  		}
   111  		udc, err = svc.GetUserDelegationCredential(ctx, info, nil)
   112  		if err != nil {
   113  			return nil, err
   114  		}
   115  		// UDC expires after PreSignedExpiry + hour but cache entry expires after an hour
   116  		c.udcCache.Add(storageAccount, udc)
   117  	} else {
   118  		udc = res.(*service.UserDelegationCredential)
   119  	}
   120  	return udc, nil
   121  }
   122  
   123  func BuildAzureServiceClient(params params.Azure) (*service.Client, error) {
   124  	var endpoint string
   125  	if params.Domain == "" {
   126  		params.Domain = BlobEndpointDefaultDomain
   127  	} else if !slices.Contains(supportedDomains, params.Domain) {
   128  		return nil, ErrInvalidDomain
   129  	}
   130  
   131  	if params.TestEndpointURL != "" { // For testing purposes - override default endpoint template
   132  		endpoint = params.TestEndpointURL
   133  	} else {
   134  		endpoint = buildAccountEndpoint(params.StorageAccount, params.Domain)
   135  	}
   136  
   137  	options := service.ClientOptions{ClientOptions: azcore.ClientOptions{Retry: policy.RetryOptions{TryTimeout: params.TryTimeout}}}
   138  	if params.StorageAccessKey != "" {
   139  		cred, err := service.NewSharedKeyCredential(params.StorageAccount, params.StorageAccessKey)
   140  		if err != nil {
   141  			return nil, fmt.Errorf("invalid credentials: %w", err)
   142  		}
   143  		return service.NewClientWithSharedKeyCredential(endpoint, cred, &options)
   144  	}
   145  
   146  	defaultCreds, err := azidentity.NewDefaultAzureCredential(nil)
   147  	if err != nil {
   148  		return nil, fmt.Errorf("missing credentials: %w", err)
   149  	}
   150  	return service.NewClient(endpoint, defaultCreds, &options)
   151  }
   152  
   153  func buildAccountEndpoint(storageAccount, domain string) string {
   154  	return fmt.Sprintf("https://%s.%s", storageAccount, domain)
   155  }