github.com/m3db/m3@v1.5.0/src/x/cache/lru_cache.go (about)

     1  // Copyright (c) 2020 Uber Technologies, Inc.
     2  //
     3  // Permission is hereby granted, free of charge, to any person obtaining a copy
     4  // of this software and associated documentation files (the "Software"), to deal
     5  // in the Software without restriction, including without limitation the rights
     6  // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
     7  // copies of the Software, and to permit persons to whom the Software is
     8  // furnished to do so, subject to the following conditions:
     9  //
    10  // The above copyright notice and this permission notice shall be included in
    11  // all copies or substantial portions of the Software.
    12  //
    13  // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
    14  // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
    15  // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
    16  // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
    17  // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
    18  // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
    19  // THE SOFTWARE.
    20  
    21  package cache
    22  
    23  import (
    24  	"container/list"
    25  	"context"
    26  	"errors"
    27  	"math"
    28  	"sync"
    29  	"time"
    30  
    31  	"github.com/uber-go/tally"
    32  	"google.golang.org/grpc/codes"
    33  	"google.golang.org/grpc/status"
    34  )
    35  
    36  // Defaults for use with the LRU cache.
    37  const (
    38  	DefaultTTL        = time.Minute * 30
    39  	DefaultMaxEntries = 10000
    40  )
    41  
    42  // Metrics names
    43  const (
    44  	loadTimesHistogram  = "load_times"
    45  	loadAttemptsCounter = "load_attempts"
    46  	loadsCounter        = "loads"
    47  	accessCounter       = "accesses"
    48  	entriesGauge        = "entries"
    49  )
    50  
    51  // Metrics tags
    52  var (
    53  	hitsTags    = map[string]string{"status": "hit"}
    54  	missesTags  = map[string]string{"status": "miss"}
    55  	successTags = map[string]string{"status": "success"}
    56  	failureTags = map[string]string{"status": "error"}
    57  )
    58  
    59  // An UncachedError can be used to wrap an error that should not be
    60  // cached, even if the cache is caching errors. The underlying error
    61  // will be unwrapped before returning to the caller.
    62  type UncachedError struct {
    63  	Err error
    64  }
    65  
    66  // Error returns the message for the underlying error
    67  func (e UncachedError) Error() string {
    68  	return e.Err.Error()
    69  }
    70  
    71  // Unwrap unwraps the underlying error
    72  func (e UncachedError) Unwrap() error {
    73  	return e.Err
    74  }
    75  
    76  // As returns true if the caller is asking for the error as an uncached error
    77  func (e UncachedError) As(target interface{}) bool {
    78  	if uncached, ok := target.(*UncachedError); ok {
    79  		uncached.Err = e.Err
    80  		return true
    81  	}
    82  	return false
    83  }
    84  
    85  // CachedError is a wrapper that can be used to force an error to
    86  // be cached. Useful when you want to opt-in to error caching. The
    87  // underlying error will be unwrapped before returning to the caller.
    88  type CachedError struct {
    89  	Err error
    90  }
    91  
    92  // Error returns the message for the underlying error
    93  func (e CachedError) Error() string {
    94  	return e.Err.Error()
    95  }
    96  
    97  // Unwrap unwraps the underlying error
    98  func (e CachedError) Unwrap() error {
    99  	return e.Err
   100  }
   101  
   102  // As returns true if the caller is asking for the error as an cached error
   103  func (e CachedError) As(target interface{}) bool {
   104  	if cached, ok := target.(*CachedError); ok {
   105  		cached.Err = e.Err
   106  		return true
   107  	}
   108  	return false
   109  }
   110  
   111  var (
   112  	_ error = UncachedError{}
   113  	_ error = CachedError{}
   114  )
   115  
   116  var (
   117  	// ErrEntryNotFound is returned if a cache entry cannot be found.
   118  	ErrEntryNotFound = status.Error(codes.NotFound, "not found")
   119  
   120  	// ErrCacheFull is returned if we need to load an entry, but the cache is already full of entries that are loading.
   121  	ErrCacheFull = status.Error(codes.ResourceExhausted, "try again later")
   122  )
   123  
// LRUOptions are the options to an LRU cache. Every field may be left at its
// zero value, in which case NewLRU substitutes the default noted on the field.
type LRUOptions struct {
	// TTL is the default lifetime of an entry when no explicit expiry is
	// given. Zero selects DefaultTTL.
	TTL                  time.Duration
	// InitialSize is the initial capacity hint for the entry map. Values <= 0
	// select min(1000, MaxEntries).
	InitialSize          int
	// MaxEntries is the maximum number of entries held. Values <= 0 select
	// DefaultMaxEntries.
	MaxEntries           int
	// MaxConcurrency, when > 0, caps how many loader calls may run at once
	// across all keys. Values <= 0 leave loader concurrency unlimited.
	MaxConcurrency       int
	// CacheErrorsByDefault makes loader errors cacheable unless they are
	// wrapped in UncachedError. Errors wrapped in CachedError are cached
	// regardless of this setting.
	CacheErrorsByDefault bool
	// Metrics is the parent scope for the cache's metrics, which are emitted
	// under an "lru-cache" sub-scope. nil selects tally.NoopScope.
	Metrics              tally.Scope
	// Now is the clock used for expiry and load timing. nil selects time.Now.
	Now                  func() time.Time
}
   134  
// LRU is a fixed size LRU cache supporting expiration, loading of entries that
// do not exist, ability to cache negative results (e.g errors from load), a
// mechanism for preventing multiple goroutines from entering the loader function
// simultaneously for the same key, and a mechanism for restricting the amount of
// total concurrency in the loader function.
//
// Construct instances with NewLRU. The zero value is not usable because its
// map, lists and clock are nil.
type LRU struct {
	// TODO(mmihic): Consider striping these mutexes + the map entry so writes only
	// take a lock out on a subset of the cache
	// mut guards entries, byAccessTime, byLoadTime and the fields of every entry.
	mut               sync.Mutex
	// metrics holds the instruments created by NewLRU.
	metrics           *lruCacheMetrics
	// cacheErrors mirrors LRUOptions.CacheErrorsByDefault.
	cacheErrors       bool
	// maxEntries is the capacity enforced by reserveCapacity.
	maxEntries        int
	// ttl is the default lifetime applied when an update carries a zero expiry.
	ttl               time.Duration
	// concurrencyLeases is a token pool pre-filled with MaxConcurrency tokens.
	// A loader takes a token before running and returns it afterwards.
	// nil means loader concurrency is unlimited.
	concurrencyLeases chan struct{}
	// now is the clock used for expiry checks and load timing.
	now               func() time.Time
	// byAccessTime orders entries by last access (front = most recent). Its
	// back is the next eviction candidate. Entries being loaded are not linked.
	byAccessTime      *list.List
	// byLoadTime orders entries by last load/put (front = most recent). Its
	// back is scanned for expired entries. Entries being loaded are not linked.
	byLoadTime        *list.List
	// entries maps each key to its entry, including entries still loading.
	entries           map[string]*lruCacheEntry
}
   154  
   155  // NewLRU returns a new LRU with the provided options.
   156  func NewLRU(opts *LRUOptions) *LRU {
   157  	if opts == nil {
   158  		opts = &LRUOptions{}
   159  	}
   160  
   161  	ttl := opts.TTL
   162  	if ttl == 0 {
   163  		ttl = DefaultTTL
   164  	}
   165  
   166  	maxEntries := opts.MaxEntries
   167  	if maxEntries <= 0 {
   168  		maxEntries = DefaultMaxEntries
   169  	}
   170  
   171  	initialSize := opts.InitialSize
   172  	if initialSize <= 0 {
   173  		initialSize = int(math.Min(1000, float64(maxEntries)))
   174  	}
   175  
   176  	tallyScope := opts.Metrics
   177  	if tallyScope == nil {
   178  		tallyScope = tally.NoopScope
   179  	}
   180  
   181  	tallyScope = tallyScope.SubScope("lru-cache")
   182  
   183  	now := opts.Now
   184  	if now == nil {
   185  		now = time.Now
   186  	}
   187  
   188  	var concurrencyLeases chan struct{}
   189  	if opts.MaxConcurrency > 0 {
   190  		concurrencyLeases = make(chan struct{}, opts.MaxConcurrency)
   191  		for i := 0; i < opts.MaxConcurrency; i++ {
   192  			concurrencyLeases <- struct{}{}
   193  		}
   194  	}
   195  
   196  	return &LRU{
   197  		ttl:               ttl,
   198  		now:               now,
   199  		maxEntries:        maxEntries,
   200  		cacheErrors:       opts.CacheErrorsByDefault,
   201  		concurrencyLeases: concurrencyLeases,
   202  		metrics: &lruCacheMetrics{
   203  			entries:       tallyScope.Gauge(entriesGauge),
   204  			hits:          tallyScope.Tagged(hitsTags).Counter(accessCounter),
   205  			misses:        tallyScope.Tagged(missesTags).Counter(accessCounter),
   206  			loadAttempts:  tallyScope.Counter(loadAttemptsCounter),
   207  			loadSuccesses: tallyScope.Tagged(successTags).Counter(loadsCounter),
   208  			loadFailures:  tallyScope.Tagged(failureTags).Counter(loadsCounter),
   209  			loadTimes:     tallyScope.Histogram(loadTimesHistogram, tally.DefaultBuckets),
   210  		},
   211  		byAccessTime: list.New(),
   212  		byLoadTime:   list.New(),
   213  		entries:      make(map[string]*lruCacheEntry, initialSize),
   214  	}
   215  }
   216  
   217  // Put puts a value directly into the cache. Uses the default TTL.
   218  func (c *LRU) Put(key string, value interface{}) {
   219  	c.PutWithTTL(key, value, 0)
   220  }
   221  
   222  // PutWithTTL puts a value directly into the cache with a custom TTL.
   223  func (c *LRU) PutWithTTL(key string, value interface{}, ttl time.Duration) {
   224  	var expiresAt time.Time
   225  	if ttl > 0 {
   226  		expiresAt = c.now().Add(ttl)
   227  	}
   228  
   229  	c.mut.Lock()
   230  	defer c.mut.Unlock()
   231  
   232  	_, _ = c.updateCacheEntryWithLock(key, expiresAt, value, nil, true)
   233  }
   234  
   235  // Get returns the value associated with the key, optionally
   236  // loading it if it does not exist or has expired.
   237  // NB(mmihic): We pass the loader as an argument rather than
   238  // making it a property of the cache to support access specific
   239  // loading arguments which might not be bundled into the key.
   240  func (c *LRU) Get(ctx context.Context, key string, loader LoaderFunc) (interface{}, error) {
   241  	return c.GetWithTTL(ctx, key, func(ctx context.Context, key string) (interface{}, time.Time, error) {
   242  		val, err := loader(ctx, key)
   243  		return val, time.Time{}, err
   244  	})
   245  }
   246  
   247  // GetWithTTL returns the value associated with the key, optionally
   248  // loading it if it does not exist or has expired, and allowing the
   249  // loader to return a TTL for the resulting value, overriding the
   250  // default TTL associated with the cache.
   251  func (c *LRU) GetWithTTL(ctx context.Context, key string, loader LoaderWithTTLFunc) (interface{}, error) {
   252  	return c.getWithTTL(ctx, key, loader)
   253  }
   254  
   255  // TryGet will simply attempt to get a key and if it does not exist and instead
   256  // of loading it if it is missing it will just return the second boolean
   257  // argument as false to indicate it is missing.
   258  func (c *LRU) TryGet(key string) (interface{}, bool) {
   259  	// Note: We want to explicitly not pass a context so that if the function
   260  	// is modified to require it that we would cause nil ptr deref (i.e.
   261  	// catch this during the change rather than at runtime cause modified
   262  	// behavior of accidentally using a non-nil background or todo context here).
   263  	// nolint: staticcheck
   264  	value, err := c.getWithTTL(nil, key, nil)
   265  	return value, err == nil
   266  }
   267  
   268  func (c *LRU) getWithTTL(
   269  	ctx context.Context,
   270  	key string,
   271  	loader LoaderWithTTLFunc,
   272  ) (interface{}, error) {
   273  	// Spin until it's either loaded or the load fails.
   274  	for {
   275  		// Inform whether we are going to use a loader or not
   276  		// to ensure correct behavior of whether to create an entry
   277  		// that will get loaded or not occurs.
   278  		getWithNoLoader := loader == nil
   279  		value, load, loadingCh, err := c.tryCached(key, getWithNoLoader)
   280  
   281  		// There was a cached error, so just return it
   282  		if err != nil {
   283  			return nil, err
   284  		}
   285  
   286  		// Someone else is loading the entry, wait for this to complete
   287  		// (or the context to end) and try to acquire again.
   288  		if loadingCh != nil {
   289  			select {
   290  			case <-ctx.Done():
   291  				return nil, ctx.Err()
   292  			case <-loadingCh:
   293  			}
   294  			continue
   295  		}
   296  
   297  		// No entry exists and no-one else is trying to load it, so we
   298  		// should try to do so (outside of the mutex lock).
   299  		if load {
   300  			if loader == nil {
   301  				return nil, ErrEntryNotFound
   302  			}
   303  
   304  			return c.tryLoad(ctx, key, loader)
   305  		}
   306  
   307  		// There is an entry and it's valid, return it.
   308  		return value, nil
   309  	}
   310  }
   311  
   312  // has checks whether the cache has the given key. Exists only to support tests.
   313  func (c *LRU) has(key string, checkExpiry bool) bool {
   314  	c.mut.Lock()
   315  	defer c.mut.Unlock()
   316  	entry, exists := c.entries[key]
   317  
   318  	if !exists {
   319  		return false
   320  	}
   321  
   322  	if checkExpiry {
   323  		return entry.loadingCh != nil || entry.expiresAt.After(c.now())
   324  	}
   325  
   326  	return true
   327  }
   328  
   329  // tryCached returns a value from the cache, or an indication of
   330  // the caller should do (return an error, load the value, wait for a concurrent
   331  // load to complete).
   332  func (c *LRU) tryCached(
   333  	key string,
   334  	getWithNoLoader bool,
   335  ) (interface{}, bool, chan struct{}, error) {
   336  	c.mut.Lock()
   337  	defer c.mut.Unlock()
   338  
   339  	entry, exists := c.entries[key]
   340  
   341  	// If a load is already in progress, tell the caller to wait for it to finish.
   342  	if exists && entry.loadingCh != nil {
   343  		return nil, false, entry.loadingCh, nil
   344  	}
   345  
   346  	// If the entry exists and has not expired, it's a hit - return it to the caller
   347  	if exists && entry.expiresAt.After(c.now()) {
   348  		c.metrics.hits.Inc(1)
   349  		c.byAccessTime.MoveToFront(entry.accessTimeElt)
   350  		return entry.value, false, nil, entry.err
   351  	}
   352  
   353  	// Otherwise we need to load it
   354  	c.metrics.misses.Inc(1)
   355  
   356  	if getWithNoLoader {
   357  		// If we're not using a loader then return entry not found
   358  		// rather than creating a loading channel since we are not trying
   359  		// to load an element we are just attempting to retrieve it if and
   360  		// only if it exists.
   361  		return nil, false, nil, ErrEntryNotFound
   362  	}
   363  
   364  	if !exists {
   365  		// The entry doesn't exist, clear enough space for it and then add it
   366  		if !c.reserveCapacity(1) {
   367  			return nil, false, nil, ErrCacheFull
   368  		}
   369  
   370  		entry = c.newEntry(key)
   371  	} else {
   372  		// The entry expired, don't consider it for eviction while we're loading
   373  		c.byAccessTime.Remove(entry.accessTimeElt)
   374  		c.byLoadTime.Remove(entry.loadTimeElt)
   375  	}
   376  
   377  	// Create a channel that other callers can block on waiting for this to complete
   378  	entry.loadingCh = make(chan struct{})
   379  	return nil, true, nil, nil
   380  }
   381  
   382  // cacheLoadComplete is called when a cache load has completed with either a value or error.
   383  func (c *LRU) cacheLoadComplete(
   384  	key string, expiresAt time.Time, value interface{}, err error,
   385  ) (interface{}, error) {
   386  	c.mut.Lock()
   387  	defer c.mut.Unlock()
   388  
   389  	if err != nil {
   390  		return c.handleCacheLoadErrorWithLock(key, expiresAt, err)
   391  	}
   392  
   393  	return c.updateCacheEntryWithLock(key, expiresAt, value, err, false)
   394  }
   395  
   396  // handleCacheLoadErrorWithLock handles the results of an error from a cache load. If
   397  // we are caching errors, updates the cache entry with the error. Otherwise
   398  // removes the cache entry and returns the (possible unwrapped) error.
   399  func (c *LRU) handleCacheLoadErrorWithLock(
   400  	key string, expiresAt time.Time, err error,
   401  ) (interface{}, error) {
   402  	// If the loader is telling us to cache this error, do so unconditionally
   403  	var cachedErr CachedError
   404  	if errors.As(err, &cachedErr) {
   405  		return c.updateCacheEntryWithLock(key, expiresAt, nil, cachedErr.Err, false)
   406  	}
   407  
   408  	// If the cache is configured to cache errors by default, do so unless
   409  	// the loader is telling us not to cache this one (e.g. it's transient)
   410  	var uncachedErr UncachedError
   411  	isUncachedError := errors.As(err, &uncachedErr)
   412  	if c.cacheErrors && !isUncachedError {
   413  		return c.updateCacheEntryWithLock(key, expiresAt, nil, err, false)
   414  	}
   415  
   416  	// Something happened during load, but we don't want to cache this - remove the entry,
   417  	// tell any blocked callers they can try again, and return the error
   418  	entry := c.entries[key]
   419  	c.remove(entry)
   420  	close(entry.loadingCh)
   421  	entry.loadingCh = nil
   422  
   423  	if isUncachedError {
   424  		return nil, uncachedErr.Err
   425  	}
   426  
   427  	return nil, err
   428  }
   429  
   430  // updateCacheEntryWithLock updates a cache entry with a new value or cached error,
   431  // and marks it as the most recently accessed and most recently loaded entry
   432  func (c *LRU) updateCacheEntryWithLock(
   433  	key string, expiresAt time.Time, value interface{}, err error, enforceLimit bool,
   434  ) (interface{}, error) {
   435  	entry := c.entries[key]
   436  	if entry == nil {
   437  		if enforceLimit && !c.reserveCapacity(1) {
   438  			// Silently skip adding the new entry if we fail to free up space for it
   439  			// (which should never be happening).
   440  			return value, err
   441  		}
   442  		entry = &lruCacheEntry{}
   443  		c.entries[key] = entry
   444  	}
   445  
   446  	entry.value, entry.err = value, err
   447  
   448  	// Re-adjust expiration and mark as both most recently access and most recently used
   449  	if expiresAt.IsZero() {
   450  		expiresAt = c.now().Add(c.ttl)
   451  	}
   452  	entry.expiresAt = expiresAt
   453  
   454  	if entry.loadTimeElt == nil {
   455  		entry.loadTimeElt = c.byLoadTime.PushFront(entry)
   456  	} else {
   457  		c.byLoadTime.MoveToFront(entry.loadTimeElt)
   458  	}
   459  
   460  	if entry.accessTimeElt == nil {
   461  		entry.accessTimeElt = c.byAccessTime.PushFront(entry)
   462  	} else {
   463  		c.byAccessTime.MoveToFront(entry.accessTimeElt)
   464  	}
   465  
   466  	c.metrics.entries.Update(float64(len(c.entries)))
   467  
   468  	// Tell any other callers that we're done loading
   469  	if entry.loadingCh != nil {
   470  		close(entry.loadingCh)
   471  		entry.loadingCh = nil
   472  	}
   473  	return value, err
   474  }
   475  
   476  // reserveCapacity evicts expired and least recently used entries (that aren't loading)
   477  // until we have at least enough space for new entries.
   478  // NB(mmihic): Must be called with the cache mutex locked.
   479  func (c *LRU) reserveCapacity(n int) bool {
   480  	// Unconditionally evict all expired entries. Entries that are expired by
   481  	// reloading are not in this list, and therefore will not be evicted.
   482  	oldestElt := c.byLoadTime.Back()
   483  	for oldestElt != nil {
   484  		entry := oldestElt.Value.(*lruCacheEntry)
   485  		if entry.expiresAt.After(c.now()) {
   486  			break
   487  		}
   488  		c.remove(entry)
   489  
   490  		oldestElt = c.byLoadTime.Back()
   491  	}
   492  
   493  	// Evict any recently accessed which are not loading, until we either run out
   494  	// of entries to evict or we have enough entries.
   495  	lruElt := c.byAccessTime.Back()
   496  	for c.maxEntries-len(c.entries) < n && lruElt != nil {
   497  		c.remove(lruElt.Value.(*lruCacheEntry))
   498  
   499  		lruElt = c.byAccessTime.Back()
   500  	}
   501  
   502  	// If we couldn't create enough space, then there are too many entries loading and the cache is simply full
   503  	if c.maxEntries-len(c.entries) < n {
   504  		return false
   505  	}
   506  
   507  	return true
   508  }
   509  
   510  // load tries to load from the loader.
   511  // NB(mmihic): Must NOT be called with the cache mutex locked.
   512  func (c *LRU) tryLoad(
   513  	ctx context.Context, key string, loader LoaderWithTTLFunc,
   514  ) (interface{}, error) {
   515  	// If we're limiting overall concurrency, acquire a concurrency lease
   516  	if c.concurrencyLeases != nil {
   517  		select {
   518  		case <-ctx.Done():
   519  			return c.cacheLoadComplete(key, time.Time{}, nil, UncachedError{ctx.Err()})
   520  		case <-c.concurrencyLeases:
   521  		}
   522  
   523  		defer func() { c.concurrencyLeases <- struct{}{} }()
   524  	}
   525  
   526  	// Increment load attempts ahead of load so we have metrics for thundering herds blocked in the loader
   527  	c.metrics.loadAttempts.Inc(1)
   528  	start := c.now()
   529  	value, expiresAt, err := loader(ctx, key)
   530  	c.metrics.loadTimes.RecordDuration(c.now().Sub(start))
   531  	if err == nil {
   532  		c.metrics.loadSuccesses.Inc(1)
   533  	} else {
   534  		c.metrics.loadFailures.Inc(1)
   535  	}
   536  
   537  	return c.cacheLoadComplete(key, expiresAt, value, err)
   538  }
   539  
   540  // remove removes an entry from the cache.
   541  // NB(mmihic): Must be called with the cache mutex locked.
   542  func (c *LRU) remove(entry *lruCacheEntry) {
   543  	delete(c.entries, entry.key)
   544  	if entry.accessTimeElt != nil {
   545  		c.byAccessTime.Remove(entry.accessTimeElt)
   546  	}
   547  
   548  	if entry.loadTimeElt != nil {
   549  		c.byLoadTime.Remove(entry.loadTimeElt)
   550  	}
   551  }
   552  
   553  // newEntry creates and adds a new cache entry.
   554  // NB(mmihic): Must be called with the cache mutex locked.
   555  func (c *LRU) newEntry(key string) *lruCacheEntry {
   556  	entry := &lruCacheEntry{key: key}
   557  	c.entries[key] = entry
   558  	return entry
   559  }
   560  
// lruCacheEntry is the per-key record stored in LRU.entries. Its fields are
// read and written with the cache mutex held.
type lruCacheEntry struct {
	// key is the entry's map key. remove() deletes the map entry by this
	// field, so every code path that creates an entry must set it.
	key           string
	// accessTimeElt is the entry's element in LRU.byAccessTime (nil until
	// first linked).
	accessTimeElt *list.Element
	// loadTimeElt is the entry's element in LRU.byLoadTime (nil until first
	// linked).
	loadTimeElt   *list.Element
	// loadingCh is non-nil while a load is in flight. It is closed and reset
	// to nil when the load finishes, which wakes the waiting callers.
	loadingCh     chan struct{}
	// expiresAt is the instant after which the entry is no longer a hit.
	expiresAt     time.Time
	// err is a cached load error, handed to callers on a hit.
	err           error
	// value is the cached value.
	value         interface{}
}
   570  
// lruCacheMetrics holds the instruments NewLRU creates under the "lru-cache"
// sub-scope.
type lruCacheMetrics struct {
	// entries is updated with len(LRU.entries) whenever an entry is stored.
	entries       tally.Gauge
	// hits and misses are the "accesses" counter tagged status=hit / status=miss.
	hits          tally.Counter
	misses        tally.Counter
	// loadAttempts is incremented before each loader call.
	loadAttempts  tally.Counter
	// loadSuccesses and loadFailures are the "loads" counter tagged
	// status=success / status=error.
	loadSuccesses tally.Counter
	loadFailures  tally.Counter
	// loadTimes records how long each loader call took.
	loadTimes     tally.Histogram
}
   580  
   581  var _ Cache = &LRU{}