github.com/rudderlabs/rudder-go-kit@v0.30.0/cachettl/cachettl.go (about)

     1  package cachettl
     2  
     3  import (
     4  	"sync"
     5  	"time"
     6  )
     7  
     8  // Cache is a double linked list sorted by expiration time (ascending order)
     9  // the root (head) node is the node with the lowest expiration time
    10  // the tail node (end) is the node with the highest expiration time
    11  // Cleanups are done on Get() calls so if Get() is never invoked then Nodes stay in-memory.
    12  type Cache[K comparable, V any] struct {
    13  	root *node[K, V]
    14  	mu   sync.Mutex
    15  	m    map[K]*node[K, V]
    16  	now  func() time.Time
    17  }
    18  
    19  type node[K comparable, V any] struct {
    20  	key        K
    21  	value      V
    22  	prev       *node[K, V]
    23  	next       *node[K, V]
    24  	ttl        time.Duration
    25  	expiration time.Time
    26  }
    27  
    28  func (n *node[K, V]) remove() {
    29  	n.prev.next = n.next
    30  	n.next.prev = n.prev
    31  }
    32  
    33  // New returns a new Cache.
    34  func New[K comparable, V any]() *Cache[K, V] {
    35  	return &Cache[K, V]{
    36  		now:  time.Now,
    37  		root: &node[K, V]{},
    38  		m:    make(map[K]*node[K, V]),
    39  	}
    40  }
    41  
    42  // Get returns the value associated with the key or nil otherwise.
    43  // Additionally, Get() will refresh the TTL and cleanup expired nodes.
    44  func (c *Cache[K, V]) Get(key K) (zero V) {
    45  	c.mu.Lock()
    46  	defer c.mu.Unlock()
    47  
    48  	defer func() { // remove expired nodes
    49  		cn := c.root.next // start from head since we're sorting by expiration with the highest expiration at the tail
    50  		for cn != nil && cn != c.root {
    51  			if c.now().After(cn.expiration) {
    52  				cn.remove()         // removes a node from the linked list (leaves the map untouched)
    53  				delete(c.m, cn.key) // remove node from map too
    54  			} else { // there is nothing else to clean up, no need to iterate further
    55  				break
    56  			}
    57  			cn = cn.next
    58  		}
    59  	}()
    60  
    61  	if n, ok := c.m[key]; ok && n.expiration.After(c.now()) {
    62  		n.remove()
    63  		n.expiration = c.now().Add(n.ttl) // refresh TTL
    64  		c.add(n)
    65  		return n.value
    66  	}
    67  	return zero
    68  }
    69  
    70  // Put adds or updates an element inside the Cache.
    71  // The Cache will be sorted with the node with the highest expiration at the tail.
    72  func (c *Cache[K, V]) Put(key K, value V, ttl time.Duration) {
    73  	c.mu.Lock()
    74  	defer c.mu.Unlock()
    75  
    76  	now := c.now()
    77  
    78  	n, ok := c.m[key]
    79  	if !ok {
    80  		n = &node[K, V]{
    81  			key: key, value: value, ttl: ttl, expiration: now.Add(ttl),
    82  		}
    83  		c.m[key] = n
    84  	} else {
    85  		n.value = value
    86  		n.expiration = now.Add(ttl)
    87  	}
    88  
    89  	if c.root.next == nil { // first node insertion
    90  		c.root.next = n
    91  		c.root.prev = n
    92  		n.prev = c.root
    93  		n.next = c.root
    94  		return
    95  	}
    96  
    97  	if ok { // removes a node from the linked list (leaves the map untouched)
    98  		n.remove()
    99  	}
   100  
   101  	c.add(n)
   102  }
   103  
   104  func (c *Cache[K, V]) add(n *node[K, V]) {
   105  	cn := c.root.prev // tail
   106  	for cn != nil {   // iterate from tail to root because we have expiring nodes towards the tail
   107  		if n.expiration.After(cn.expiration) || n.expiration.Equal(cn.expiration) {
   108  			// insert node after cn
   109  			save := cn.next
   110  			cn.next = n
   111  			n.prev = cn
   112  			n.next = save
   113  			save.prev = n
   114  			break
   115  		}
   116  		cn = cn.prev
   117  	}
   118  }
   119  
   120  // slice is used for debugging purposes only
   121  func (c *Cache[K, V]) slice() (s []V) {
   122  	c.mu.Lock()
   123  	defer c.mu.Unlock()
   124  
   125  	cn := c.root.next
   126  	for cn != nil && cn != c.root {
   127  		s = append(s, cn.value)
   128  		cn = cn.next
   129  	}
   130  	return
   131  }