github.com/phuslu/lru@v1.0.16-0.20240421170520-46288a2fd47c/ttl_cache.go

// Copyright 2023-2024 Phus Lu. All rights reserved.

package lru

import (
	"context"
	"sync/atomic"
	"time"
	"unsafe"
)

// TTLCache implements an LRU cache with TTL functionality.
type TTLCache[K comparable, V any] struct {
	shards [512]ttlshard[K, V]
	mask   uint32
	hasher func(key unsafe.Pointer, seed uintptr) uintptr
	seed   uintptr
	loader func(ctx context.Context, key K) (value V, ttl time.Duration, err error)
	group  singleflight_Group[K, V]
}

// NewTTLCache creates an LRU cache with TTL support and the given size capacity.
func NewTTLCache[K comparable, V any](size int, options ...Option[K, V]) *TTLCache[K, V] {
	clocking()

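	// Ensure a shards option is applied before the other options: default to
	// WithShards(0) when none is supplied, otherwise move the caller's shards
	// option to the front of the list.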
	j := -1
	for i, o := range options {
		if _, ok := o.(*shardsOption[K, V]); ok {
			j = i
		}
	}
	switch {
	case j < 0:
		options = append([]Option[K, V]{WithShards[K, V](0)}, options...)
	case j > 0:
		options[0], options[j] = options[j], options[0]
	}

	c := new(TTLCache[K, V])
	for _, o := range options {
		o.applyToTTLCache(c)
	}

	if c.hasher == nil {
		c.hasher = getRuntimeHasher[K]()
	}
	if c.seed == 0 {
		c.seed = uintptr(fastrand64())
	}

	if isamd64 {
		// pre-alloc lists and tables for compactness
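		// shardsize is the per-shard capacity, rounded up so that all shards
		// together hold at least size entries; each shard's list and table
		// buckets are carved out of two shared backing slices.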
		shardsize := (uint32(size) + c.mask) / (c.mask + 1)
		shardlists := make([]ttlnode[K, V], (shardsize+1)*(c.mask+1))
		tablesize := ttlNewTableSize(uint32(shardsize))
		tablebuckets := make([]uint64, tablesize*(c.mask+1))
		for i := uint32(0); i <= c.mask; i++ {
			c.shards[i].list = shardlists[i*(shardsize+1) : (i+1)*(shardsize+1)]
			c.shards[i].table_buckets = tablebuckets[i*tablesize : (i+1)*tablesize]
			c.shards[i].Init(shardsize, c.hasher, c.seed)
		}
	} else {
		shardsize := (uint32(size) + c.mask) / (c.mask + 1)
		for i := uint32(0); i <= c.mask; i++ {
			c.shards[i].Init(shardsize, c.hasher, c.seed)
		}
	}

	return c
}

// Get returns the value for key.
func (c *TTLCache[K, V]) Get(key K) (value V, ok bool) {
	hash := uint32(c.hasher(noescape(unsafe.Pointer(&key)), c.seed))
	// return c.shards[hash&c.mask].Get(hash, key)
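	// Equivalent to the commented-out line above; the manual pointer
	// arithmetic avoids indexing (and bounds-checking) the shards array.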
	return (*ttlshard[K, V])(unsafe.Add(unsafe.Pointer(&c.shards[0]), uintptr(hash&c.mask)*unsafe.Sizeof(c.shards[0]))).Get(hash, key)
}

// GetOrLoad returns the value for key, calling the loader function via singleflight if the value is not in the cache.
func (c *TTLCache[K, V]) GetOrLoad(ctx context.Context, key K, loader func(context.Context, K) (V, time.Duration, error)) (value V, err error, ok bool) {
	hash := uint32(c.hasher(noescape(unsafe.Pointer(&key)), c.seed))
	// value, ok = c.shards[hash&c.mask].Get(hash, key)
	value, ok = (*ttlshard[K, V])(unsafe.Add(unsafe.Pointer(&c.shards[0]), uintptr(hash&c.mask)*unsafe.Sizeof(c.shards[0]))).Get(hash, key)
	if !ok {
		if loader == nil {
			loader = c.loader
		}
		if loader == nil {
			err = ErrLoaderIsNil
			return
		}
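		// Collapse concurrent loads of the same key into a single loader
		// call, then cache the result with the TTL returned by the loader.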
		value, err, ok = c.group.Do(key, func() (V, error) {
			v, ttl, err := loader(ctx, key)
			if err != nil {
				return v, err
			}
			c.shards[hash&c.mask].Set(hash, key, v, ttl)
			return v, nil
		})
	}
	return
}

// Peek returns the value and its expiration time in nanoseconds for key, without modifying its recency.
func (c *TTLCache[K, V]) Peek(key K) (value V, expires int64, ok bool) {
	hash := uint32(c.hasher(noescape(unsafe.Pointer(&key)), c.seed))
	// return c.shards[hash&c.mask].Peek(hash, key)
	return (*ttlshard[K, V])(unsafe.Add(unsafe.Pointer(&c.shards[0]), uintptr(hash&c.mask)*unsafe.Sizeof(c.shards[0]))).Peek(hash, key)
}

// Set inserts a key value pair and returns the previous value.
func (c *TTLCache[K, V]) Set(key K, value V, ttl time.Duration) (prev V, replaced bool) {
	hash := uint32(c.hasher(noescape(unsafe.Pointer(&key)), c.seed))
	// return c.shards[hash&c.mask].Set(hash, key, value, ttl)
	return (*ttlshard[K, V])(unsafe.Add(unsafe.Pointer(&c.shards[0]), uintptr(hash&c.mask)*unsafe.Sizeof(c.shards[0]))).Set(hash, key, value, ttl)
}

// SetIfAbsent inserts the key value pair only if the key is absent from the cache, and returns the previous value.
func (c *TTLCache[K, V]) SetIfAbsent(key K, value V, ttl time.Duration) (prev V, replaced bool) {
	hash := uint32(c.hasher(noescape(unsafe.Pointer(&key)), c.seed))
	// return c.shards[hash&c.mask].SetIfAbsent(hash, key, value, ttl)
	return (*ttlshard[K, V])(unsafe.Add(unsafe.Pointer(&c.shards[0]), uintptr(hash&c.mask)*unsafe.Sizeof(c.shards[0]))).SetIfAbsent(hash, key, value, ttl)
}

// Delete removes the value associated with key and returns the deleted value (or the zero value if key was not in the cache).
func (c *TTLCache[K, V]) Delete(key K) (prev V) {
	hash := uint32(c.hasher(noescape(unsafe.Pointer(&key)), c.seed))
	// return c.shards[hash&c.mask].Delete(hash, key)
	return (*ttlshard[K, V])(unsafe.Add(unsafe.Pointer(&c.shards[0]), uintptr(hash&c.mask)*unsafe.Sizeof(c.shards[0]))).Delete(hash, key)
}

// Len returns the number of cached nodes.
func (c *TTLCache[K, V]) Len() int {
	var n uint32
	for i := uint32(0); i <= c.mask; i++ {
		n += c.shards[i].Len()
	}
	return int(n)
}

// AppendKeys appends all keys to keys and returns the keys.
func (c *TTLCache[K, V]) AppendKeys(keys []K) []K {
	now := atomic.LoadUint32(&clock)
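	// Load the shared clock once so every shard sees the same notion of "now"
	// while appending its keys.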
	for i := uint32(0); i <= c.mask; i++ {
		keys = c.shards[i].AppendKeys(keys, now)
	}
	return keys
}

// Stats returns cache stats.
func (c *TTLCache[K, V]) Stats() (stats Stats) {
	for i := uint32(0); i <= c.mask; i++ {
		s := &c.shards[i]
		s.mu.Lock()
		stats.EntriesCount += uint64(s.table_length)
		stats.GetCalls += s.stats_getcalls
		stats.SetCalls += s.stats_setcalls
		stats.Misses += s.stats_misses
		s.mu.Unlock()
	}
	return
}
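
A minimal usage sketch of the API defined above. The import path comes from the file header; the key/value types, capacity, and TTL values are illustrative assumptions, not part of the library.

package main

import (
	"context"
	"fmt"
	"time"

	"github.com/phuslu/lru"
)

func main() {
	// Create a TTL cache with capacity for 8192 entries.
	cache := lru.NewTTLCache[string, int](8192)

	// Set and Get with a per-entry TTL.
	cache.Set("answer", 42, time.Minute)
	if v, ok := cache.Get("answer"); ok {
		fmt.Println("got", v)
	}

	// GetOrLoad: on a miss the loader runs (deduplicated by singleflight)
	// and its result is cached with the TTL it returns.
	v, err, _ := cache.GetOrLoad(context.Background(), "lazy",
		func(ctx context.Context, key string) (int, time.Duration, error) {
			return len(key), 10 * time.Second, nil
		})
	fmt.Println(v, err)
}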