github.com/badrootd/celestia-core@v0.0.0-20240305091328-aa4207a4b25d/mempool/cat/cache_test.go (about)

     1  package cat
     2  
     3  import (
     4  	"crypto/rand"
     5  	"fmt"
     6  	"sync"
     7  	"testing"
     8  	"time"
     9  
    10  	"github.com/stretchr/testify/require"
    11  
    12  	"github.com/badrootd/celestia-core/types"
    13  )
    14  
    15  func TestSeenTxSet(t *testing.T) {
    16  	var (
    17  		tx1Key        = types.Tx("tx1").Key()
    18  		tx2Key        = types.Tx("tx2").Key()
    19  		tx3Key        = types.Tx("tx3").Key()
    20  		peer1  uint16 = 1
    21  		peer2  uint16 = 2
    22  	)
    23  
    24  	seenSet := NewSeenTxSet()
    25  	require.Nil(t, seenSet.Get(tx1Key))
    26  
    27  	seenSet.Add(tx1Key, peer1)
    28  	seenSet.Add(tx1Key, peer1)
    29  	require.Equal(t, 1, seenSet.Len())
    30  	seenSet.Add(tx1Key, peer2)
    31  	peers := seenSet.Get(tx1Key)
    32  	require.NotNil(t, peers)
    33  	require.Equal(t, map[uint16]struct{}{peer1: {}, peer2: {}}, peers)
    34  	seenSet.Add(tx2Key, peer1)
    35  	seenSet.Add(tx3Key, peer1)
    36  	require.Equal(t, 3, seenSet.Len())
    37  	seenSet.RemoveKey(tx2Key)
    38  	require.Equal(t, 2, seenSet.Len())
    39  	require.Nil(t, seenSet.Get(tx2Key))
    40  	require.True(t, seenSet.Has(tx3Key, peer1))
    41  }
    42  
    43  func TestLRUTxCacheRemove(t *testing.T) {
    44  	cache := NewLRUTxCache(100)
    45  	numTxs := 10
    46  
    47  	txs := make([][32]byte, numTxs)
    48  	for i := 0; i < numTxs; i++ {
    49  		// probability of collision is 2**-256
    50  		txBytes := make([]byte, 32)
    51  		_, err := rand.Read(txBytes)
    52  		require.NoError(t, err)
    53  
    54  		copy(txs[i][:], txBytes)
    55  		cache.Push(txs[i])
    56  
    57  		// make sure its added to both the linked list and the map
    58  		require.Equal(t, i+1, cache.list.Len())
    59  	}
    60  
    61  	for i := 0; i < numTxs; i++ {
    62  		cache.Remove(txs[i])
    63  		// make sure its removed from both the map and the linked list
    64  		require.Equal(t, numTxs-(i+1), cache.list.Len())
    65  	}
    66  }
    67  
    68  func TestLRUTxCacheSize(t *testing.T) {
    69  	const size = 10
    70  	cache := NewLRUTxCache(size)
    71  
    72  	for i := 0; i < size*2; i++ {
    73  		tx := types.Tx([]byte(fmt.Sprintf("tx%d", i)))
    74  		cache.Push(tx.Key())
    75  		require.Less(t, cache.list.Len(), size+1)
    76  	}
    77  }
    78  
    79  func TestSeenTxSetConcurrency(t *testing.T) {
    80  	seenSet := NewSeenTxSet()
    81  
    82  	const (
    83  		concurrency = 10
    84  		numTx       = 100
    85  	)
    86  
    87  	wg := sync.WaitGroup{}
    88  	for i := 0; i < concurrency; i++ {
    89  		wg.Add(1)
    90  		go func(peer uint16) {
    91  			defer wg.Done()
    92  			for i := 0; i < numTx; i++ {
    93  				tx := types.Tx([]byte(fmt.Sprintf("tx%d", i)))
    94  				seenSet.Add(tx.Key(), peer)
    95  			}
    96  		}(uint16(i % 2))
    97  	}
    98  	time.Sleep(time.Millisecond)
    99  	for i := 0; i < concurrency; i++ {
   100  		wg.Add(1)
   101  		go func(peer uint16) {
   102  			defer wg.Done()
   103  			for i := 0; i < numTx; i++ {
   104  				tx := types.Tx([]byte(fmt.Sprintf("tx%d", i)))
   105  				seenSet.Has(tx.Key(), peer)
   106  			}
   107  		}(uint16(i % 2))
   108  	}
   109  	time.Sleep(time.Millisecond)
   110  	for i := 0; i < concurrency; i++ {
   111  		wg.Add(1)
   112  		go func(peer uint16) {
   113  			defer wg.Done()
   114  			for i := numTx - 1; i >= 0; i-- {
   115  				tx := types.Tx([]byte(fmt.Sprintf("tx%d", i)))
   116  				seenSet.RemoveKey(tx.Key())
   117  			}
   118  		}(uint16(i % 2))
   119  	}
   120  	wg.Wait()
   121  }
   122  
   123  func TestLRUTxCacheConcurrency(t *testing.T) {
   124  	cache := NewLRUTxCache(100)
   125  
   126  	const (
   127  		concurrency = 10
   128  		numTx       = 100
   129  	)
   130  
   131  	wg := sync.WaitGroup{}
   132  	for i := 0; i < concurrency; i++ {
   133  		wg.Add(1)
   134  		go func() {
   135  			defer wg.Done()
   136  			for i := 0; i < numTx; i++ {
   137  				tx := types.Tx([]byte(fmt.Sprintf("tx%d", i)))
   138  				cache.Push(tx.Key())
   139  			}
   140  			for i := 0; i < numTx; i++ {
   141  				tx := types.Tx([]byte(fmt.Sprintf("tx%d", i)))
   142  				cache.Has(tx.Key())
   143  			}
   144  			for i := numTx - 1; i >= 0; i-- {
   145  				tx := types.Tx([]byte(fmt.Sprintf("tx%d", i)))
   146  				cache.Remove(tx.Key())
   147  			}
   148  		}()
   149  	}
   150  	wg.Wait()
   151  }