github.com/badrootd/celestia-core@v0.0.0-20240305091328-aa4207a4b25d/mempool/cat/cache.go (about)

     1  package cat
     2  
     3  import (
     4  	"container/list"
     5  	"time"
     6  
     7  	tmsync "github.com/badrootd/celestia-core/libs/sync"
     8  	"github.com/badrootd/celestia-core/types"
     9  )
    10  
// LRUTxCache maintains a thread-safe LRU cache of raw transactions. The cache
// only stores the hash of the raw transaction.
// NOTE: This has been copied from mempool/cache with the main difference of using
// tx keys instead of raw transactions.
type LRUTxCache struct {
	// staticSize is the maximum number of keys held by the cache. When it is
	// zero the cache is disabled: Push always returns true, Has always returns
	// false and Remove does nothing.
	staticSize int

	// mtx guards cacheMap and list
	mtx tmsync.Mutex
	// cacheMap is used as a quick look up table from a key to its list element
	cacheMap map[types.TxKey]*list.Element
	// list is a doubly linked list that keeps the cached keys in recency
	// order: Push appends new keys (or moves existing ones) to the back and
	// evicts from the front once the cache is full
	list *list.List
}
    24  
    25  func NewLRUTxCache(cacheSize int) *LRUTxCache {
    26  	return &LRUTxCache{
    27  		staticSize: cacheSize,
    28  		cacheMap:   make(map[types.TxKey]*list.Element, cacheSize),
    29  		list:       list.New(),
    30  	}
    31  }
    32  
    33  func (c *LRUTxCache) Reset() {
    34  	c.mtx.Lock()
    35  	defer c.mtx.Unlock()
    36  
    37  	c.cacheMap = make(map[types.TxKey]*list.Element, c.staticSize)
    38  	c.list.Init()
    39  }
    40  
    41  func (c *LRUTxCache) Push(txKey types.TxKey) bool {
    42  	if c.staticSize == 0 {
    43  		return true
    44  	}
    45  
    46  	c.mtx.Lock()
    47  	defer c.mtx.Unlock()
    48  
    49  	moved, ok := c.cacheMap[txKey]
    50  	if ok {
    51  		c.list.MoveToBack(moved)
    52  		return false
    53  	}
    54  
    55  	if c.list.Len() >= c.staticSize {
    56  		front := c.list.Front()
    57  		if front != nil {
    58  			frontKey := front.Value.(types.TxKey)
    59  			delete(c.cacheMap, frontKey)
    60  			c.list.Remove(front)
    61  		}
    62  	}
    63  
    64  	e := c.list.PushBack(txKey)
    65  	c.cacheMap[txKey] = e
    66  
    67  	return true
    68  }
    69  
    70  func (c *LRUTxCache) Remove(txKey types.TxKey) {
    71  	if c.staticSize == 0 {
    72  		return
    73  	}
    74  
    75  	c.mtx.Lock()
    76  	defer c.mtx.Unlock()
    77  
    78  	e := c.cacheMap[txKey]
    79  	delete(c.cacheMap, txKey)
    80  
    81  	if e != nil {
    82  		c.list.Remove(e)
    83  	}
    84  }
    85  
    86  func (c *LRUTxCache) Has(txKey types.TxKey) bool {
    87  	if c.staticSize == 0 {
    88  		return false
    89  	}
    90  
    91  	c.mtx.Lock()
    92  	defer c.mtx.Unlock()
    93  
    94  	_, ok := c.cacheMap[txKey]
    95  	return ok
    96  }
    97  
// SeenTxSet records transactions that have been
// seen by other peers but not yet by us. For each transaction key it keeps
// the set of peers that were added for it and the time at which the key was
// first added.
type SeenTxSet struct {
	// mtx guards set
	mtx tmsync.Mutex
	// set maps a transaction key to the peers recorded as having seen it
	set map[types.TxKey]timestampedPeerSet
}
   104  
// timestampedPeerSet is the per-transaction value stored in SeenTxSet: a set
// of peers together with the creation time of the entry.
type timestampedPeerSet struct {
	peers map[uint16]struct{} // peers recorded as having seen the transaction
	time  time.Time           // set once, in UTC, when the entry is created; used by Prune
}
   109  
   110  func NewSeenTxSet() *SeenTxSet {
   111  	return &SeenTxSet{
   112  		set: make(map[types.TxKey]timestampedPeerSet),
   113  	}
   114  }
   115  
   116  func (s *SeenTxSet) Add(txKey types.TxKey, peer uint16) {
   117  	if peer == 0 {
   118  		return
   119  	}
   120  	s.mtx.Lock()
   121  	defer s.mtx.Unlock()
   122  	seenSet, exists := s.set[txKey]
   123  	if !exists {
   124  		s.set[txKey] = timestampedPeerSet{
   125  			peers: map[uint16]struct{}{peer: {}},
   126  			time:  time.Now().UTC(),
   127  		}
   128  	} else {
   129  		seenSet.peers[peer] = struct{}{}
   130  	}
   131  }
   132  
   133  func (s *SeenTxSet) RemoveKey(txKey types.TxKey) {
   134  	s.mtx.Lock()
   135  	defer s.mtx.Unlock()
   136  	delete(s.set, txKey)
   137  }
   138  
   139  func (s *SeenTxSet) Remove(txKey types.TxKey, peer uint16) {
   140  	s.mtx.Lock()
   141  	defer s.mtx.Unlock()
   142  	set, exists := s.set[txKey]
   143  	if exists {
   144  		if len(set.peers) == 1 {
   145  			delete(s.set, txKey)
   146  		} else {
   147  			delete(set.peers, peer)
   148  		}
   149  	}
   150  }
   151  
   152  func (s *SeenTxSet) RemovePeer(peer uint16) {
   153  	s.mtx.Lock()
   154  	defer s.mtx.Unlock()
   155  	for key, seenSet := range s.set {
   156  		delete(seenSet.peers, peer)
   157  		if len(seenSet.peers) == 0 {
   158  			delete(s.set, key)
   159  		}
   160  	}
   161  }
   162  
   163  func (s *SeenTxSet) Prune(limit time.Time) {
   164  	s.mtx.Lock()
   165  	defer s.mtx.Unlock()
   166  	for key, seenSet := range s.set {
   167  		if seenSet.time.Before(limit) {
   168  			delete(s.set, key)
   169  		}
   170  	}
   171  }
   172  
   173  func (s *SeenTxSet) Has(txKey types.TxKey, peer uint16) bool {
   174  	s.mtx.Lock()
   175  	defer s.mtx.Unlock()
   176  	seenSet, exists := s.set[txKey]
   177  	if !exists {
   178  		return false
   179  	}
   180  	_, has := seenSet.peers[peer]
   181  	return has
   182  }
   183  
   184  func (s *SeenTxSet) Get(txKey types.TxKey) map[uint16]struct{} {
   185  	s.mtx.Lock()
   186  	defer s.mtx.Unlock()
   187  	seenSet, exists := s.set[txKey]
   188  	if !exists {
   189  		return nil
   190  	}
   191  	// make a copy of the struct to avoid concurrency issues
   192  	peers := make(map[uint16]struct{}, len(seenSet.peers))
   193  	for peer := range seenSet.peers {
   194  		peers[peer] = struct{}{}
   195  	}
   196  	return peers
   197  }
   198  
   199  // Len returns the amount of cached items. Mostly used for testing.
   200  func (s *SeenTxSet) Len() int {
   201  	s.mtx.Lock()
   202  	defer s.mtx.Unlock()
   203  	return len(s.set)
   204  }
   205  
   206  func (s *SeenTxSet) Reset() {
   207  	s.mtx.Lock()
   208  	defer s.mtx.Unlock()
   209  	s.set = make(map[types.TxKey]timestampedPeerSet)
   210  }