github.com/phuslu/lru@v1.0.16-0.20240421170520-46288a2fd47c/lru_cache.go (about)

     1  // Copyright 2023-2024 Phus Lu. All rights reserved.
     2  
// Package lru implements a cache with a least-recently-used eviction policy.
     4  package lru
     5  
     6  import (
     7  	"context"
     8  	"unsafe"
     9  )
    10  
// LRUCache is a cache with a least-recently-used eviction policy.
//
// Entries are spread over a fixed pool of shards; a key's shard is chosen
// by hash & mask, so only shards 0 through mask are ever used.
type LRUCache[K comparable, V any] struct {
	// shards is the fixed shard pool; only the first mask+1 are active.
	shards [512]lrushard[K, V]
	// mask selects the shard for a key hash (hash & mask).
	// NOTE(review): presumably set by the shards option to a power of two
	// minus one that is at most 511 — confirm; the unsafe shard indexing in
	// the accessor methods is only sound while mask < len(shards).
	mask   uint32
	// hasher hashes a key (passed by pointer) with seed.
	hasher func(key unsafe.Pointer, seed uintptr) uintptr
	// seed is passed to hasher; NewLRUCache randomizes it if left zero.
	seed   uintptr
	// loader is the default loader GetOrLoad uses when called with a nil loader.
	loader func(ctx context.Context, key K) (value V, err error)
	// group runs GetOrLoad's loads via singleflight.
	group  singleflight_Group[K, V]
}
    20  
    21  // NewLRUCache creates lru cache with size capacity.
    22  func NewLRUCache[K comparable, V any](size int, options ...Option[K, V]) *LRUCache[K, V] {
    23  	j := -1
    24  	for i, o := range options {
    25  		if _, ok := o.(*shardsOption[K, V]); ok {
    26  			j = i
    27  		}
    28  	}
    29  	switch {
    30  	case j < 0:
    31  		options = append([]Option[K, V]{WithShards[K, V](0)}, options...)
    32  	case j > 0:
    33  		options[0], options[j] = options[j], options[0]
    34  	}
    35  
    36  	c := new(LRUCache[K, V])
    37  	for _, o := range options {
    38  		o.applyToLRUCache(c)
    39  	}
    40  
    41  	if c.hasher == nil {
    42  		c.hasher = getRuntimeHasher[K]()
    43  	}
    44  	if c.seed == 0 {
    45  		c.seed = uintptr(fastrand64())
    46  	}
    47  
    48  	if isamd64 {
    49  		// pre-alloc lists and tables for compactness
    50  		shardsize := (uint32(size) + c.mask) / (c.mask + 1)
    51  		shardlists := make([]lrunode[K, V], (shardsize+1)*(c.mask+1))
    52  		tablesize := lruNewTableSize(uint32(shardsize))
    53  		tablebuckets := make([]uint64, tablesize*(c.mask+1))
    54  		for i := uint32(0); i <= c.mask; i++ {
    55  			c.shards[i].list = shardlists[i*(shardsize+1) : (i+1)*(shardsize+1)]
    56  			c.shards[i].table_buckets = tablebuckets[i*tablesize : (i+1)*tablesize]
    57  			c.shards[i].Init(shardsize, c.hasher, c.seed)
    58  		}
    59  	} else {
    60  		shardsize := (uint32(size) + c.mask) / (c.mask + 1)
    61  		for i := uint32(0); i <= c.mask; i++ {
    62  			c.shards[i].Init(shardsize, c.hasher, c.seed)
    63  		}
    64  	}
    65  
    66  	return c
    67  }
    68  
    69  // Get returns value for key.
    70  func (c *LRUCache[K, V]) Get(key K) (value V, ok bool) {
    71  	hash := uint32(c.hasher(noescape(unsafe.Pointer(&key)), c.seed))
    72  	// return c.shards[hash&c.mask].Get(hash, key)
    73  	return (*lrushard[K, V])(unsafe.Add(unsafe.Pointer(&c.shards[0]), uintptr(hash&c.mask)*unsafe.Sizeof(c.shards[0]))).Get(hash, key)
    74  }
    75  
    76  // GetOrLoad returns value for key, call loader function by singleflight if value was not in cache.
    77  func (c *LRUCache[K, V]) GetOrLoad(ctx context.Context, key K, loader func(context.Context, K) (V, error)) (value V, err error, ok bool) {
    78  	hash := uint32(c.hasher(noescape(unsafe.Pointer(&key)), c.seed))
    79  	value, ok = c.shards[hash&c.mask].Get(hash, key)
    80  	if !ok {
    81  		if loader == nil {
    82  			loader = c.loader
    83  		}
    84  		if loader == nil {
    85  			err = ErrLoaderIsNil
    86  			return
    87  		}
    88  		value, err, ok = c.group.Do(key, func() (V, error) {
    89  			v, err := loader(ctx, key)
    90  			if err != nil {
    91  				return v, err
    92  			}
    93  			c.shards[hash&c.mask].Set(hash, key, v)
    94  			return v, nil
    95  		})
    96  	}
    97  	return
    98  }
    99  
   100  // Peek returns value, but does not modify its recency.
   101  func (c *LRUCache[K, V]) Peek(key K) (value V, ok bool) {
   102  	hash := uint32(c.hasher(noescape(unsafe.Pointer(&key)), c.seed))
   103  	// return c.shards[hash&c.mask].Peek(hash, key)
   104  	return (*lrushard[K, V])(unsafe.Add(unsafe.Pointer(&c.shards[0]), uintptr(hash&c.mask)*unsafe.Sizeof(c.shards[0]))).Peek(hash, key)
   105  }
   106  
   107  // Set inserts key value pair and returns previous value.
   108  func (c *LRUCache[K, V]) Set(key K, value V) (prev V, replaced bool) {
   109  	hash := uint32(c.hasher(noescape(unsafe.Pointer(&key)), c.seed))
   110  	// return c.shards[hash&c.mask].Set(hash, key, value)
   111  	return (*lrushard[K, V])(unsafe.Add(unsafe.Pointer(&c.shards[0]), uintptr(hash&c.mask)*unsafe.Sizeof(c.shards[0]))).Set(hash, key, value)
   112  }
   113  
   114  // SetIfAbsent inserts key value pair and returns previous value, if key is absent in the cache.
   115  func (c *LRUCache[K, V]) SetIfAbsent(key K, value V) (prev V, replaced bool) {
   116  	hash := uint32(c.hasher(noescape(unsafe.Pointer(&key)), c.seed))
   117  	// return c.shards[hash&c.mask].SetIfAbsent(hash, key, value)
   118  	return (*lrushard[K, V])(unsafe.Add(unsafe.Pointer(&c.shards[0]), uintptr(hash&c.mask)*unsafe.Sizeof(c.shards[0]))).SetIfAbsent(hash, key, value)
   119  }
   120  
   121  // Delete method deletes value associated with key and returns deleted value (or empty value if key was not in cache).
   122  func (c *LRUCache[K, V]) Delete(key K) (prev V) {
   123  	hash := uint32(c.hasher(noescape(unsafe.Pointer(&key)), c.seed))
   124  	// return c.shards[hash&c.mask].Delete(hash, key)
   125  	return (*lrushard[K, V])(unsafe.Add(unsafe.Pointer(&c.shards[0]), uintptr(hash&c.mask)*unsafe.Sizeof(c.shards[0]))).Delete(hash, key)
   126  }
   127  
   128  // Len returns number of cached nodes.
   129  func (c *LRUCache[K, V]) Len() int {
   130  	var n uint32
   131  	for i := uint32(0); i <= c.mask; i++ {
   132  		n += c.shards[i].Len()
   133  	}
   134  	return int(n)
   135  }
   136  
   137  // AppendKeys appends all keys to keys and return the keys.
   138  func (c *LRUCache[K, V]) AppendKeys(keys []K) []K {
   139  	for i := uint32(0); i <= c.mask; i++ {
   140  		keys = c.shards[i].AppendKeys(keys)
   141  	}
   142  	return keys
   143  }
   144  
   145  // Stats returns cache stats.
   146  func (c *LRUCache[K, V]) Stats() (stats Stats) {
   147  	for i := uint32(0); i <= c.mask; i++ {
   148  		s := &c.shards[i]
   149  		s.mu.Lock()
   150  		stats.EntriesCount += uint64(s.table_length)
   151  		stats.GetCalls += s.stats_getcalls
   152  		stats.SetCalls += s.stats_setcalls
   153  		stats.Misses += s.stats_misses
   154  		s.mu.Unlock()
   155  	}
   156  	return
   157  }