github.com/m3db/m3@v1.5.1-0.20231129193456-75a402aa583b/src/x/cache/lru_cache.go (about)

     1  // Copyright (c) 2020 Uber Technologies, Inc.
     2  //
     3  // Permission is hereby granted, free of charge, to any person obtaining a copy
     4  // of this software and associated documentation files (the "Software"), to deal
     5  // in the Software without restriction, including without limitation the rights
     6  // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
     7  // copies of the Software, and to permit persons to whom the Software is
     8  // furnished to do so, subject to the following conditions:
     9  //
    10  // The above copyright notice and this permission notice shall be included in
    11  // all copies or substantial portions of the Software.
    12  //
    13  // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
    14  // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
    15  // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
    16  // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
    17  // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
    18  // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
    19  // THE SOFTWARE.
    20  
    21  package cache
    22  
    23  import (
    24  	"container/list"
    25  	"context"
    26  	"errors"
    27  	"math"
    28  	"sync"
    29  	"time"
    30  
    31  	"github.com/uber-go/tally"
    32  	"google.golang.org/grpc/codes"
    33  	"google.golang.org/grpc/status"
    34  )
    35  
    36  // Defaults for use with the LRU cache.
    37  const (
    38  	DefaultTTL        = time.Minute * 30
    39  	DefaultMaxEntries = 10000
    40  )
    41  
    42  // Metrics names
    43  const (
    44  	loadTimesHistogram  = "load_times"
    45  	loadAttemptsCounter = "load_attempts"
    46  	loadsCounter        = "loads"
    47  	accessCounter       = "accesses"
    48  	entriesGauge        = "entries"
    49  )
    50  
    51  // Metrics tags
    52  var (
    53  	hitsTags    = map[string]string{"status": "hit"}
    54  	missesTags  = map[string]string{"status": "miss"}
    55  	successTags = map[string]string{"status": "success"}
    56  	failureTags = map[string]string{"status": "error"}
    57  )
    58  
    59  // An UncachedError can be used to wrap an error that should not be
    60  // cached, even if the cache is caching errors. The underlying error
    61  // will be unwrapped before returning to the caller.
    62  type UncachedError struct {
    63  	Err error
    64  }
    65  
    66  // Error returns the message for the underlying error
    67  func (e UncachedError) Error() string {
    68  	return e.Err.Error()
    69  }
    70  
    71  // Unwrap unwraps the underlying error
    72  func (e UncachedError) Unwrap() error {
    73  	return e.Err
    74  }
    75  
    76  // As returns true if the caller is asking for the error as an uncached error
    77  func (e UncachedError) As(target interface{}) bool {
    78  	if uncached, ok := target.(*UncachedError); ok {
    79  		uncached.Err = e.Err
    80  		return true
    81  	}
    82  	return false
    83  }
    84  
    85  // CachedError is a wrapper that can be used to force an error to
    86  // be cached. Useful when you want to opt-in to error caching. The
    87  // underlying error will be unwrapped before returning to the caller.
    88  type CachedError struct {
    89  	Err error
    90  }
    91  
    92  // Error returns the message for the underlying error
    93  func (e CachedError) Error() string {
    94  	return e.Err.Error()
    95  }
    96  
    97  // Unwrap unwraps the underlying error
    98  func (e CachedError) Unwrap() error {
    99  	return e.Err
   100  }
   101  
   102  // As returns true if the caller is asking for the error as an cached error
   103  func (e CachedError) As(target interface{}) bool {
   104  	if cached, ok := target.(*CachedError); ok {
   105  		cached.Err = e.Err
   106  		return true
   107  	}
   108  	return false
   109  }
   110  
   111  var (
   112  	_ error = UncachedError{}
   113  	_ error = CachedError{}
   114  )
   115  
   116  var (
   117  	// ErrEntryNotFound is returned if a cache entry cannot be found.
   118  	ErrEntryNotFound = status.Error(codes.NotFound, "not found")
   119  
   120  	// ErrCacheFull is returned if we need to load an entry, but the cache is already full of entries that are loading.
   121  	ErrCacheFull = status.Error(codes.ResourceExhausted, "try again later")
   122  )
   123  
// LRUOptions are the options to an LRU cache. The zero value (or a nil
// *LRUOptions passed to NewLRU) is valid; NewLRU substitutes a default for
// every unset field.
type LRUOptions struct {
	// TTL is the lifetime given to entries stored or loaded without an explicit
	// expiry. Zero selects DefaultTTL; a negative value makes such entries
	// expire immediately.
	TTL time.Duration

	// InitialSize is the capacity hint for the entry map. Values <= 0 select
	// min(1000, MaxEntries).
	InitialSize int

	// MaxEntries bounds the number of entries. Values <= 0 select
	// DefaultMaxEntries.
	MaxEntries int

	// MaxConcurrency, when positive, caps how many loader calls may run at once
	// across all keys; further loads wait for a slot or for their context to
	// end. Zero or negative means unlimited.
	MaxConcurrency int

	// CacheErrorsByDefault caches loader errors like values unless they are
	// wrapped in UncachedError. When false, only errors wrapped in CachedError
	// are cached.
	CacheErrorsByDefault bool

	// Metrics is the tally scope metrics are reported to, under the
	// "lru-cache" sub-scope. Nil selects tally.NoopScope.
	Metrics tally.Scope

	// Now is the clock used for expiry and load timing. Nil selects time.Now.
	Now func() time.Time
}
   134  
// LRU is a fixed size LRU cache supporting expiration, loading of entries that
// do not exist, ability to cache negative results (e.g. errors from load), a
// mechanism for preventing multiple goroutines from entering the loader function
// simultaneously for the same key, and a mechanism for restricting the amount of
// total concurrency in the loader function. Construct it with NewLRU.
type LRU struct {
	// TODO(mmihic): Consider striping these mutexes + the map entry so writes only
	// take a lock out on a subset of the cache
	//
	// mut serializes access to entries, both lists, and the fields of every
	// lruCacheEntry.
	mut sync.Mutex
	// metrics holds the tally instruments created in NewLRU.
	metrics *lruCacheMetrics
	// cacheErrors mirrors LRUOptions.CacheErrorsByDefault.
	cacheErrors bool
	// maxEntries is the upper bound on len(entries).
	maxEntries int
	// ttl is the default entry lifetime, applied when no explicit expiry is given.
	ttl time.Duration
	// concurrencyLeases is a buffered channel pre-filled with MaxConcurrency
	// tokens; a loader takes one before running and returns it afterwards. It is
	// nil when loader concurrency is unlimited.
	concurrencyLeases chan struct{}
	// now is the clock used for expiry checks and load timing.
	now func() time.Time
	// byAccessTime orders entries with the most recently accessed at the
	// front; LRU eviction takes from the back.
	byAccessTime *list.List
	// byLoadTime orders entries with the most recently loaded/stored at the
	// front; the expiry sweep in reserveCapacity scans from the back.
	byLoadTime *list.List
	// entries maps keys to entries, including entries whose load is still in
	// flight.
	entries map[string]*lruCacheEntry
}
   154  
   155  // NewLRU returns a new LRU with the provided options.
   156  func NewLRU(opts *LRUOptions) *LRU {
   157  	if opts == nil {
   158  		opts = &LRUOptions{}
   159  	}
   160  
   161  	ttl := opts.TTL
   162  	if ttl == 0 {
   163  		ttl = DefaultTTL
   164  	}
   165  
   166  	maxEntries := opts.MaxEntries
   167  	if maxEntries <= 0 {
   168  		maxEntries = DefaultMaxEntries
   169  	}
   170  
   171  	initialSize := opts.InitialSize
   172  	if initialSize <= 0 {
   173  		initialSize = int(math.Min(1000, float64(maxEntries)))
   174  	}
   175  
   176  	tallyScope := opts.Metrics
   177  	if tallyScope == nil {
   178  		tallyScope = tally.NoopScope
   179  	}
   180  
   181  	tallyScope = tallyScope.SubScope("lru-cache")
   182  
   183  	now := opts.Now
   184  	if now == nil {
   185  		now = time.Now
   186  	}
   187  
   188  	var concurrencyLeases chan struct{}
   189  	if opts.MaxConcurrency > 0 {
   190  		concurrencyLeases = make(chan struct{}, opts.MaxConcurrency)
   191  		for i := 0; i < opts.MaxConcurrency; i++ {
   192  			concurrencyLeases <- struct{}{}
   193  		}
   194  	}
   195  
   196  	return &LRU{
   197  		ttl:               ttl,
   198  		now:               now,
   199  		maxEntries:        maxEntries,
   200  		cacheErrors:       opts.CacheErrorsByDefault,
   201  		concurrencyLeases: concurrencyLeases,
   202  		metrics: &lruCacheMetrics{
   203  			entries:       tallyScope.Gauge(entriesGauge),
   204  			hits:          tallyScope.Tagged(hitsTags).Counter(accessCounter),
   205  			misses:        tallyScope.Tagged(missesTags).Counter(accessCounter),
   206  			loadAttempts:  tallyScope.Counter(loadAttemptsCounter),
   207  			loadSuccesses: tallyScope.Tagged(successTags).Counter(loadsCounter),
   208  			loadFailures:  tallyScope.Tagged(failureTags).Counter(loadsCounter),
   209  			loadTimes:     tallyScope.Histogram(loadTimesHistogram, tally.DefaultBuckets),
   210  		},
   211  		byAccessTime: list.New(),
   212  		byLoadTime:   list.New(),
   213  		entries:      make(map[string]*lruCacheEntry, initialSize),
   214  	}
   215  }
   216  
   217  // Put puts a value directly into the cache. Uses the default TTL.
   218  func (c *LRU) Put(key string, value interface{}) {
   219  	c.PutWithTTL(key, value, 0)
   220  }
   221  
   222  // PutWithTTL puts a value directly into the cache with a custom TTL.
   223  func (c *LRU) PutWithTTL(key string, value interface{}, ttl time.Duration) {
   224  	var expiresAt time.Time
   225  	if ttl > 0 {
   226  		expiresAt = c.now().Add(ttl)
   227  	}
   228  
   229  	c.mut.Lock()
   230  	defer c.mut.Unlock()
   231  
   232  	_, _ = c.updateCacheEntryWithLock(key, expiresAt, value, nil, true)
   233  }
   234  
   235  // Get returns the value associated with the key, optionally
   236  // loading it if it does not exist or has expired.
   237  // NB(mmihic): We pass the loader as an argument rather than
   238  // making it a property of the cache to support access specific
   239  // loading arguments which might not be bundled into the key.
   240  func (c *LRU) Get(ctx context.Context, key string, loader LoaderFunc) (interface{}, error) {
   241  	if loader == nil {
   242  		return c.GetWithTTL(ctx, key, nil)
   243  	}
   244  
   245  	return c.GetWithTTL(ctx, key, func(ctx context.Context, key string) (interface{}, time.Time, error) {
   246  		val, err := loader(ctx, key)
   247  		return val, time.Time{}, err
   248  	})
   249  }
   250  
   251  // GetWithTTL returns the value associated with the key, optionally
   252  // loading it if it does not exist or has expired, and allowing the
   253  // loader to return a TTL for the resulting value, overriding the
   254  // default TTL associated with the cache.
   255  func (c *LRU) GetWithTTL(ctx context.Context, key string, loader LoaderWithTTLFunc) (interface{}, error) {
   256  	return c.getWithTTL(ctx, key, loader)
   257  }
   258  
   259  // TryGet will simply attempt to get a key and if it does not exist and instead
   260  // of loading it if it is missing it will just return the second boolean
   261  // argument as false to indicate it is missing.
   262  func (c *LRU) TryGet(key string) (interface{}, bool) {
   263  	// Note: We want to explicitly not pass a context so that if the function
   264  	// is modified to require it that we would cause nil ptr deref (i.e.
   265  	// catch this during the change rather than at runtime cause modified
   266  	// behavior of accidentally using a non-nil background or todo context here).
   267  	// nolint: staticcheck
   268  	value, err := c.getWithTTL(nil, key, nil)
   269  	return value, err == nil
   270  }
   271  
   272  func (c *LRU) getWithTTL(
   273  	ctx context.Context,
   274  	key string,
   275  	loader LoaderWithTTLFunc,
   276  ) (interface{}, error) {
   277  	// Spin until it's either loaded or the load fails.
   278  	for {
   279  		// Inform whether we are going to use a loader or not
   280  		// to ensure correct behavior of whether to create an entry
   281  		// that will get loaded or not occurs.
   282  		getWithNoLoader := loader == nil
   283  		value, load, loadingCh, err := c.tryCached(key, getWithNoLoader)
   284  
   285  		// There was a cached error, so just return it
   286  		if err != nil {
   287  			return nil, err
   288  		}
   289  
   290  		// Someone else is loading the entry, wait for this to complete
   291  		// (or the context to end) and try to acquire again.
   292  		if loadingCh != nil {
   293  			select {
   294  			case <-ctx.Done():
   295  				return nil, ctx.Err()
   296  			case <-loadingCh:
   297  			}
   298  			continue
   299  		}
   300  
   301  		// No entry exists and no-one else is trying to load it, so we
   302  		// should try to do so (outside of the mutex lock).
   303  		if load {
   304  			if loader == nil {
   305  				return nil, ErrEntryNotFound
   306  			}
   307  
   308  			return c.tryLoad(ctx, key, loader)
   309  		}
   310  
   311  		// There is an entry and it's valid, return it.
   312  		return value, nil
   313  	}
   314  }
   315  
   316  // has checks whether the cache has the given key. Exists only to support tests.
   317  func (c *LRU) has(key string, checkExpiry bool) bool {
   318  	c.mut.Lock()
   319  	defer c.mut.Unlock()
   320  	entry, exists := c.entries[key]
   321  
   322  	if !exists {
   323  		return false
   324  	}
   325  
   326  	if checkExpiry {
   327  		return entry.loadingCh != nil || entry.expiresAt.After(c.now())
   328  	}
   329  
   330  	return true
   331  }
   332  
   333  // tryCached returns a value from the cache, or an indication of
   334  // the caller should do (return an error, load the value, wait for a concurrent
   335  // load to complete).
   336  func (c *LRU) tryCached(
   337  	key string,
   338  	getWithNoLoader bool,
   339  ) (interface{}, bool, chan struct{}, error) {
   340  	c.mut.Lock()
   341  	defer c.mut.Unlock()
   342  
   343  	entry, exists := c.entries[key]
   344  
   345  	// If a load is already in progress, tell the caller to wait for it to finish.
   346  	if exists && entry.loadingCh != nil {
   347  		return nil, false, entry.loadingCh, nil
   348  	}
   349  
   350  	// If the entry exists and has not expired, it's a hit - return it to the caller
   351  	if exists && entry.expiresAt.After(c.now()) {
   352  		c.metrics.hits.Inc(1)
   353  		c.byAccessTime.MoveToFront(entry.accessTimeElt)
   354  		return entry.value, false, nil, entry.err
   355  	}
   356  
   357  	// Otherwise we need to load it
   358  	c.metrics.misses.Inc(1)
   359  
   360  	if getWithNoLoader {
   361  		// If we're not using a loader then return entry not found
   362  		// rather than creating a loading channel since we are not trying
   363  		// to load an element we are just attempting to retrieve it if and
   364  		// only if it exists.
   365  		return nil, false, nil, ErrEntryNotFound
   366  	}
   367  
   368  	if !exists {
   369  		// The entry doesn't exist, clear enough space for it and then add it
   370  		if !c.reserveCapacity(1) {
   371  			return nil, false, nil, ErrCacheFull
   372  		}
   373  
   374  		entry = c.newEntry(key)
   375  	} else {
   376  		// The entry expired, don't consider it for eviction while we're loading
   377  		c.byAccessTime.Remove(entry.accessTimeElt)
   378  		c.byLoadTime.Remove(entry.loadTimeElt)
   379  	}
   380  
   381  	// Create a channel that other callers can block on waiting for this to complete
   382  	entry.loadingCh = make(chan struct{})
   383  	return nil, true, nil, nil
   384  }
   385  
   386  // cacheLoadComplete is called when a cache load has completed with either a value or error.
   387  func (c *LRU) cacheLoadComplete(
   388  	key string, expiresAt time.Time, value interface{}, err error,
   389  ) (interface{}, error) {
   390  	c.mut.Lock()
   391  	defer c.mut.Unlock()
   392  
   393  	if err != nil {
   394  		return c.handleCacheLoadErrorWithLock(key, expiresAt, err)
   395  	}
   396  
   397  	return c.updateCacheEntryWithLock(key, expiresAt, value, err, false)
   398  }
   399  
   400  // handleCacheLoadErrorWithLock handles the results of an error from a cache load. If
   401  // we are caching errors, updates the cache entry with the error. Otherwise
   402  // removes the cache entry and returns the (possible unwrapped) error.
   403  func (c *LRU) handleCacheLoadErrorWithLock(
   404  	key string, expiresAt time.Time, err error,
   405  ) (interface{}, error) {
   406  	// If the loader is telling us to cache this error, do so unconditionally
   407  	var cachedErr CachedError
   408  	if errors.As(err, &cachedErr) {
   409  		return c.updateCacheEntryWithLock(key, expiresAt, nil, cachedErr.Err, false)
   410  	}
   411  
   412  	// If the cache is configured to cache errors by default, do so unless
   413  	// the loader is telling us not to cache this one (e.g. it's transient)
   414  	var uncachedErr UncachedError
   415  	isUncachedError := errors.As(err, &uncachedErr)
   416  	if c.cacheErrors && !isUncachedError {
   417  		return c.updateCacheEntryWithLock(key, expiresAt, nil, err, false)
   418  	}
   419  
   420  	// Something happened during load, but we don't want to cache this - remove the entry,
   421  	// tell any blocked callers they can try again, and return the error
   422  	entry := c.entries[key]
   423  	c.remove(entry)
   424  	close(entry.loadingCh)
   425  	entry.loadingCh = nil
   426  
   427  	if isUncachedError {
   428  		return nil, uncachedErr.Err
   429  	}
   430  
   431  	return nil, err
   432  }
   433  
   434  // updateCacheEntryWithLock updates a cache entry with a new value or cached error,
   435  // and marks it as the most recently accessed and most recently loaded entry
   436  func (c *LRU) updateCacheEntryWithLock(
   437  	key string, expiresAt time.Time, value interface{}, err error, enforceLimit bool,
   438  ) (interface{}, error) {
   439  	entry := c.entries[key]
   440  	if entry == nil {
   441  		if enforceLimit && !c.reserveCapacity(1) {
   442  			// Silently skip adding the new entry if we fail to free up space for it
   443  			// (which should never be happening).
   444  			return value, err
   445  		}
   446  		entry = &lruCacheEntry{}
   447  		c.entries[key] = entry
   448  	}
   449  
   450  	entry.key, entry.value, entry.err = key, value, err
   451  
   452  	// Re-adjust expiration and mark as both most recently access and most recently used
   453  	if expiresAt.IsZero() {
   454  		expiresAt = c.now().Add(c.ttl)
   455  	}
   456  	entry.expiresAt = expiresAt
   457  
   458  	if entry.loadTimeElt == nil {
   459  		entry.loadTimeElt = c.byLoadTime.PushFront(entry)
   460  	} else {
   461  		c.byLoadTime.MoveToFront(entry.loadTimeElt)
   462  	}
   463  
   464  	if entry.accessTimeElt == nil {
   465  		entry.accessTimeElt = c.byAccessTime.PushFront(entry)
   466  	} else {
   467  		c.byAccessTime.MoveToFront(entry.accessTimeElt)
   468  	}
   469  
   470  	c.metrics.entries.Update(float64(len(c.entries)))
   471  
   472  	// Tell any other callers that we're done loading
   473  	if entry.loadingCh != nil {
   474  		close(entry.loadingCh)
   475  		entry.loadingCh = nil
   476  	}
   477  	return value, err
   478  }
   479  
   480  // reserveCapacity evicts expired and least recently used entries (that aren't loading)
   481  // until we have at least enough space for new entries.
   482  // NB(mmihic): Must be called with the cache mutex locked.
   483  func (c *LRU) reserveCapacity(n int) bool {
   484  	// Unconditionally evict all expired entries. Entries that are expired by
   485  	// reloading are not in this list, and therefore will not be evicted.
   486  	oldestElt := c.byLoadTime.Back()
   487  	for oldestElt != nil {
   488  		entry := oldestElt.Value.(*lruCacheEntry)
   489  		if entry.expiresAt.After(c.now()) {
   490  			break
   491  		}
   492  		c.remove(entry)
   493  
   494  		oldestElt = c.byLoadTime.Back()
   495  	}
   496  
   497  	// Evict any recently accessed which are not loading, until we either run out
   498  	// of entries to evict or we have enough entries.
   499  	lruElt := c.byAccessTime.Back()
   500  	for c.maxEntries-len(c.entries) < n && lruElt != nil {
   501  		c.remove(lruElt.Value.(*lruCacheEntry))
   502  
   503  		lruElt = c.byAccessTime.Back()
   504  	}
   505  
   506  	// If we couldn't create enough space, then there are too many entries loading and the cache is simply full
   507  	if c.maxEntries-len(c.entries) < n {
   508  		return false
   509  	}
   510  
   511  	return true
   512  }
   513  
   514  // load tries to load from the loader.
   515  // NB(mmihic): Must NOT be called with the cache mutex locked.
   516  func (c *LRU) tryLoad(
   517  	ctx context.Context, key string, loader LoaderWithTTLFunc,
   518  ) (interface{}, error) {
   519  	// If we're limiting overall concurrency, acquire a concurrency lease
   520  	if c.concurrencyLeases != nil {
   521  		select {
   522  		case <-ctx.Done():
   523  			return c.cacheLoadComplete(key, time.Time{}, nil, UncachedError{ctx.Err()})
   524  		case <-c.concurrencyLeases:
   525  		}
   526  
   527  		defer func() { c.concurrencyLeases <- struct{}{} }()
   528  	}
   529  
   530  	// Increment load attempts ahead of load so we have metrics for thundering herds blocked in the loader
   531  	c.metrics.loadAttempts.Inc(1)
   532  	start := c.now()
   533  	value, expiresAt, err := loader(ctx, key)
   534  	c.metrics.loadTimes.RecordDuration(c.now().Sub(start))
   535  	if err == nil {
   536  		c.metrics.loadSuccesses.Inc(1)
   537  	} else {
   538  		c.metrics.loadFailures.Inc(1)
   539  	}
   540  
   541  	return c.cacheLoadComplete(key, expiresAt, value, err)
   542  }
   543  
   544  // remove removes an entry from the cache.
   545  // NB(mmihic): Must be called with the cache mutex locked.
   546  func (c *LRU) remove(entry *lruCacheEntry) {
   547  	delete(c.entries, entry.key)
   548  	if entry.accessTimeElt != nil {
   549  		c.byAccessTime.Remove(entry.accessTimeElt)
   550  	}
   551  
   552  	if entry.loadTimeElt != nil {
   553  		c.byLoadTime.Remove(entry.loadTimeElt)
   554  	}
   555  }
   556  
   557  // newEntry creates and adds a new cache entry.
   558  // NB(mmihic): Must be called with the cache mutex locked.
   559  func (c *LRU) newEntry(key string) *lruCacheEntry {
   560  	entry := &lruCacheEntry{key: key}
   561  	c.entries[key] = entry
   562  	return entry
   563  }
   564  
// lruCacheEntry is a single cache slot. All fields are guarded by LRU.mut.
type lruCacheEntry struct {
	// key is the entry's map key, kept so remove can delete it from the map.
	key string
	// accessTimeElt is the entry's element in LRU.byAccessTime; nil until the
	// entry is first stored.
	accessTimeElt *list.Element
	// loadTimeElt is the entry's element in LRU.byLoadTime; nil until the
	// entry is first stored.
	loadTimeElt *list.Element
	// loadingCh is non-nil while a load of this key is in flight. It is closed
	// (and reset to nil) when the load finishes, waking waiting callers.
	loadingCh chan struct{}
	// expiresAt is when the entry stops being served as a hit.
	expiresAt time.Time
	// err is a cached load error, returned to callers instead of value.
	err error
	// value is the cached value.
	value interface{}
}
   574  
// lruCacheMetrics bundles the tally instruments created in NewLRU.
type lruCacheMetrics struct {
	// entries reports len(LRU.entries) each time an entry is written.
	entries tally.Gauge
	// hits counts lookups that found a live entry (including cached errors).
	hits tally.Counter
	// misses counts lookups that found no live entry.
	misses tally.Counter
	// loadAttempts counts loader invocations, incremented before the call.
	loadAttempts tally.Counter
	// loadSuccesses counts loads that returned a nil error.
	loadSuccesses tally.Counter
	// loadFailures counts loads that returned an error.
	loadFailures tally.Counter
	// loadTimes records how long each loader call took.
	loadTimes tally.Histogram
}
   584  
   585  var _ Cache = &LRU{}