github.com/koko1123/flow-go-1@v0.29.6/module/mempool/stdmap/backend_test.go (about)

     1  // (c) 2019 Dapper Labs - ALL RIGHTS RESERVED
     2  
     3  package stdmap_test
     4  
     5  import (
     6  	"fmt"
     7  	"sync"
     8  	"sync/atomic"
     9  	"testing"
    10  	"time"
    11  
    12  	"github.com/stretchr/testify/assert"
    13  	"github.com/stretchr/testify/require"
    14  
    15  	"github.com/koko1123/flow-go-1/model/flow"
    16  	"github.com/koko1123/flow-go-1/module/mempool"
    17  	"github.com/koko1123/flow-go-1/module/mempool/stdmap"
    18  	"github.com/koko1123/flow-go-1/utils/unittest"
    19  )
    20  
    21  func TestAddRemove(t *testing.T) {
    22  	item1 := unittest.MockEntityFixture()
    23  	item2 := unittest.MockEntityFixture()
    24  
    25  	t.Run("should be able to add and rem", func(t *testing.T) {
    26  		pool := stdmap.NewBackend()
    27  		added := pool.Add(item1)
    28  		require.True(t, added)
    29  		added = pool.Add(item2)
    30  		require.True(t, added)
    31  
    32  		t.Run("should be able to get size", func(t *testing.T) {
    33  			size := pool.Size()
    34  			assert.EqualValues(t, uint(2), size)
    35  		})
    36  
    37  		t.Run("should be able to get first", func(t *testing.T) {
    38  			gotItem, exists := pool.ByID(item1.ID())
    39  			assert.True(t, exists)
    40  			assert.Equal(t, item1, gotItem)
    41  		})
    42  
    43  		t.Run("should be able to remove first", func(t *testing.T) {
    44  			removed := pool.Remove(item1.ID())
    45  			assert.True(t, removed)
    46  			size := pool.Size()
    47  			assert.EqualValues(t, uint(1), size)
    48  		})
    49  
    50  		t.Run("should be able to retrieve all", func(t *testing.T) {
    51  			items := pool.All()
    52  			require.Len(t, items, 1)
    53  			assert.Equal(t, item2, items[0])
    54  		})
    55  	})
    56  }
    57  
    58  func TestAdjust(t *testing.T) {
    59  	item1 := unittest.MockEntityFixture()
    60  	item2 := unittest.MockEntityFixture()
    61  
    62  	t.Run("should not adjust if not exist", func(t *testing.T) {
    63  		pool := stdmap.NewBackend()
    64  		_ = pool.Add(item1)
    65  
    66  		// item2 doesn't exist
    67  		updatedItem, updated := pool.Adjust(item2.ID(), func(old flow.Entity) flow.Entity {
    68  			return item2
    69  		})
    70  
    71  		assert.False(t, updated)
    72  		assert.Nil(t, updatedItem)
    73  
    74  		_, found := pool.ByID(item2.ID())
    75  		assert.False(t, found)
    76  	})
    77  
    78  	t.Run("should adjust if exists", func(t *testing.T) {
    79  		pool := stdmap.NewBackend()
    80  		_ = pool.Add(item1)
    81  
    82  		updatedItem, ok := pool.Adjust(item1.ID(), func(old flow.Entity) flow.Entity {
    83  			// item 1 exist, got replaced with item2, the value was updated
    84  			return item2
    85  		})
    86  
    87  		assert.True(t, ok)
    88  		assert.Equal(t, updatedItem, item2)
    89  
    90  		value2, found := pool.ByID(item2.ID())
    91  		assert.True(t, found)
    92  		assert.Equal(t, value2, item2)
    93  	})
    94  }
    95  
    96  // Test that size mempool de-duplicates based on ID
    97  func Test_DeduplicationByID(t *testing.T) {
    98  	item1 := unittest.MockEntityFixture()
    99  	item2 := unittest.MockEntity{Identifier: item1.Identifier} // duplicate
   100  	assert.True(t, item1.ID() == item2.ID())
   101  
   102  	pool := stdmap.NewBackend()
   103  	pool.Add(item1)
   104  	pool.Add(item2)
   105  	assert.Equal(t, uint(1), pool.Size())
   106  }
   107  
   108  // TestBackend_RunLimitChecking defines a backend with size limit of `limit`. It then
   109  // starts adding `swarm`-many items concurrently to the backend each on a separate goroutine,
   110  // where `swarm` > `limit`,
   111  // and evaluates that size of the map stays within the limit.
   112  func TestBackend_RunLimitChecking(t *testing.T) {
   113  	const (
   114  		limit = 150
   115  		swarm = 150
   116  	)
   117  	pool := stdmap.NewBackend(stdmap.WithLimit(limit))
   118  
   119  	wg := sync.WaitGroup{}
   120  	wg.Add(swarm)
   121  
   122  	for i := 0; i < swarm; i++ {
   123  		go func(x int) {
   124  			// creates and adds a fake item to the mempool
   125  			item := unittest.MockEntityFixture()
   126  			_ = pool.Run(func(backdata mempool.BackData) error {
   127  				added := backdata.Add(item.ID(), item)
   128  				if !added {
   129  					return fmt.Errorf("potential race condition on adding to back data")
   130  				}
   131  
   132  				return nil
   133  			})
   134  
   135  			// evaluates that the size remains in the permissible range
   136  			require.True(t, pool.Size() <= uint(limit),
   137  				fmt.Sprintf("size violation: should be at most: %d, got: %d", limit, pool.Size()))
   138  			wg.Done()
   139  		}(i)
   140  	}
   141  
   142  	unittest.RequireReturnsBefore(t, wg.Wait, 1*time.Second, "test could not finish on time")
   143  }
   144  
   145  // TestBackend_RegisterEjectionCallback verifies that the Backend calls the
   146  // ejection callbacks whenever it ejects a stored entity due to size limitations.
   147  func TestBackend_RegisterEjectionCallback(t *testing.T) {
   148  	const (
   149  		limit = 20
   150  		swarm = 20
   151  	)
   152  	pool := stdmap.NewBackend(stdmap.WithLimit(limit))
   153  
   154  	// on ejection callback: test whether ejected identity is no longer part of the mempool
   155  	ensureEntityNotInMempool := func(entity flow.Entity) {
   156  		id := entity.ID()
   157  		go func() {
   158  			e, found := pool.ByID(id)
   159  			require.False(t, found)
   160  			require.Nil(t, e)
   161  		}()
   162  		go func() {
   163  			require.False(t, pool.Has(id))
   164  		}()
   165  	}
   166  	pool.RegisterEjectionCallbacks(ensureEntityNotInMempool)
   167  
   168  	wg := sync.WaitGroup{}
   169  	wg.Add(swarm)
   170  	for i := 0; i < swarm; i++ {
   171  		go func(x int) {
   172  			// creates and adds a fake item to the mempool
   173  			item := unittest.MockEntityFixture()
   174  			pool.Add(item)
   175  			wg.Done()
   176  		}(i)
   177  	}
   178  
   179  	unittest.RequireReturnsBefore(t, wg.Wait, 1*time.Second, "test could not finish on time")
   180  	require.Equal(t, uint(limit), pool.Size(), "expected mempool to be at max capacity limit")
   181  }
   182  
   183  // TestBackend_Multiple_OnEjectionCallbacks verifies that the Backend
   184  // handles multiple ejection callbacks correctly
   185  func TestBackend_Multiple_OnEjectionCallbacks(t *testing.T) {
   186  	// ejection callback counts number of calls
   187  	calls := uint64(0)
   188  	callback := func(entity flow.Entity) {
   189  		atomic.AddUint64(&calls, 1)
   190  	}
   191  
   192  	// construct backend
   193  	const (
   194  		limit = 30
   195  	)
   196  	pool := stdmap.NewBackend(stdmap.WithLimit(limit))
   197  	pool.RegisterEjectionCallbacks(callback, callback)
   198  
   199  	t.Run("fill mempool up to limit", func(t *testing.T) {
   200  		addRandomEntities(t, pool, limit)
   201  		require.Equal(t, uint(limit), pool.Size(), "expected mempool to be at max capacity limit")
   202  		require.Equal(t, uint64(0), atomic.LoadUint64(&calls))
   203  	})
   204  
   205  	t.Run("add elements beyond limit", func(t *testing.T) {
   206  		addRandomEntities(t, pool, 2) // as we registered callback _twice_, we should receive 2 calls per ejection
   207  		require.Less(t, uint(limit), pool.Size(), "expected mempool to be at max capacity limit")
   208  		require.Equal(t, uint64(0), atomic.LoadUint64(&calls))
   209  	})
   210  
   211  	t.Run("fill mempool up to limit", func(t *testing.T) {
   212  		atomic.StoreUint64(&calls, uint64(0))
   213  		pool.RegisterEjectionCallbacks(callback) // now we have registered the callback three times
   214  		addRandomEntities(t, pool, 7)            // => we should receive 3 calls per ejection
   215  		require.Less(t, uint(limit), pool.Size(), "expected mempool to be at max capacity limit")
   216  		require.Equal(t, uint64(0), atomic.LoadUint64(&calls))
   217  	})
   218  }
   219  
   220  func addRandomEntities(t *testing.T, backend *stdmap.Backend, num int) {
   221  	// add swarm-number of items to backend
   222  	wg := sync.WaitGroup{}
   223  	wg.Add(num)
   224  	for ; num > 0; num-- {
   225  		go func() {
   226  			backend.Add(unittest.MockEntityFixture()) // creates and adds a fake item to the mempool
   227  			wg.Done()
   228  		}()
   229  	}
   230  	unittest.RequireReturnsBefore(t, wg.Wait, 1*time.Second, "failed to add elements in time")
   231  }
   232  
   233  func TestBackend_All(t *testing.T) {
   234  	backend := stdmap.NewBackend()
   235  	entities := unittest.EntityListFixture(100)
   236  
   237  	// Add
   238  	for _, e := range entities {
   239  		// all entities must be stored successfully
   240  		require.True(t, backend.Add(e))
   241  	}
   242  
   243  	// All
   244  	all := backend.All()
   245  	require.Equal(t, len(entities), len(all))
   246  	for _, expected := range entities {
   247  		actual, ok := backend.ByID(expected.ID())
   248  		require.True(t, ok)
   249  		require.Equal(t, expected, actual)
   250  	}
   251  }