github.com/pf-qiu/concourse/v6@v6.7.3-0.20201207032516-1f455d73275f/atc/api/accessor/claims_cacher.go (about)

     1  package accessor
     2  
     3  import (
     4  	"encoding/json"
     5  	"sync"
     6  
     7  	"github.com/pf-qiu/concourse/v6/atc/db"
     8  	"github.com/golang/groupcache/lru"
     9  )
    10  
    11  type claimsCacheEntry struct {
    12  	claims db.Claims
    13  	size   int
    14  }
    15  
    16  type claimsCacher struct {
    17  	accessTokenFetcher AccessTokenFetcher
    18  	maxCacheSizeBytes  int
    19  
    20  	cache          *lru.Cache
    21  	cacheSizeBytes int
    22  	mu             sync.Mutex // lru.Cache is not safe for concurrent access
    23  }
    24  
    25  func NewClaimsCacher(
    26  	accessTokenFetcher AccessTokenFetcher,
    27  	maxCacheSizeBytes int,
    28  ) *claimsCacher {
    29  	c := &claimsCacher{
    30  		accessTokenFetcher: accessTokenFetcher,
    31  		maxCacheSizeBytes:  maxCacheSizeBytes,
    32  		cache:              lru.New(0),
    33  	}
    34  	c.cache.OnEvicted = func(_ lru.Key, value interface{}) {
    35  		entry, _ := value.(claimsCacheEntry)
    36  		c.cacheSizeBytes -= entry.size
    37  	}
    38  
    39  	return c
    40  }
    41  
    42  func (c *claimsCacher) GetAccessToken(rawToken string) (db.AccessToken, bool, error) {
    43  	c.mu.Lock()
    44  	defer c.mu.Unlock()
    45  
    46  	claims, found := c.cache.Get(rawToken)
    47  	if found {
    48  		entry, _ := claims.(claimsCacheEntry)
    49  		return db.AccessToken{Token: rawToken, Claims: entry.claims}, true, nil
    50  	}
    51  
    52  	token, found, err := c.accessTokenFetcher.GetAccessToken(rawToken)
    53  	if err != nil {
    54  		return db.AccessToken{}, false, err
    55  	}
    56  	payload, err := json.Marshal(token.Claims)
    57  	if err != nil {
    58  		return db.AccessToken{}, false, err
    59  	}
    60  	entry := claimsCacheEntry{claims: token.Claims, size: len(payload)}
    61  	c.cache.Add(rawToken, entry)
    62  	c.cacheSizeBytes += entry.size
    63  
    64  	for c.cacheSizeBytes > c.maxCacheSizeBytes && c.cache.Len() > 0 {
    65  		c.cache.RemoveOldest()
    66  	}
    67  
    68  	return token, true, nil
    69  }