github.com/onflow/flow-go@v0.35.7-crescendo-preview.23-atree-inlining/module/mempool/stdmap/backend_test.go (about)

     1  package stdmap_test
     2  
     3  import (
     4  	"fmt"
     5  	"sync"
     6  	"sync/atomic"
     7  	"testing"
     8  	"time"
     9  
    10  	"github.com/stretchr/testify/assert"
    11  	"github.com/stretchr/testify/require"
    12  
    13  	"github.com/onflow/flow-go/model/flow"
    14  	"github.com/onflow/flow-go/module/mempool"
    15  	herocache "github.com/onflow/flow-go/module/mempool/herocache/backdata"
    16  	"github.com/onflow/flow-go/module/mempool/herocache/backdata/heropool"
    17  	"github.com/onflow/flow-go/module/mempool/stdmap"
    18  	"github.com/onflow/flow-go/module/metrics"
    19  	"github.com/onflow/flow-go/utils/unittest"
    20  )
    21  
    22  func TestAddRemove(t *testing.T) {
    23  	item1 := unittest.MockEntityFixture()
    24  	item2 := unittest.MockEntityFixture()
    25  
    26  	t.Run("should be able to add and rem", func(t *testing.T) {
    27  		pool := stdmap.NewBackend()
    28  		added := pool.Add(item1)
    29  		require.True(t, added)
    30  		added = pool.Add(item2)
    31  		require.True(t, added)
    32  
    33  		t.Run("should be able to get size", func(t *testing.T) {
    34  			size := pool.Size()
    35  			assert.EqualValues(t, uint(2), size)
    36  		})
    37  
    38  		t.Run("should be able to get first", func(t *testing.T) {
    39  			gotItem, exists := pool.ByID(item1.ID())
    40  			assert.True(t, exists)
    41  			assert.Equal(t, item1, gotItem)
    42  		})
    43  
    44  		t.Run("should be able to remove first", func(t *testing.T) {
    45  			removed := pool.Remove(item1.ID())
    46  			assert.True(t, removed)
    47  			size := pool.Size()
    48  			assert.EqualValues(t, uint(1), size)
    49  		})
    50  
    51  		t.Run("should be able to retrieve all", func(t *testing.T) {
    52  			items := pool.All()
    53  			require.Len(t, items, 1)
    54  			assert.Equal(t, item2, items[0])
    55  		})
    56  	})
    57  }
    58  
    59  func TestAdjust(t *testing.T) {
    60  	item1 := unittest.MockEntityFixture()
    61  	item2 := unittest.MockEntityFixture()
    62  
    63  	t.Run("should not adjust if not exist", func(t *testing.T) {
    64  		pool := stdmap.NewBackend()
    65  		_ = pool.Add(item1)
    66  
    67  		// item2 doesn't exist
    68  		updatedItem, updated := pool.Adjust(item2.ID(), func(old flow.Entity) flow.Entity {
    69  			return item2
    70  		})
    71  
    72  		assert.False(t, updated)
    73  		assert.Nil(t, updatedItem)
    74  
    75  		_, found := pool.ByID(item2.ID())
    76  		assert.False(t, found)
    77  	})
    78  
    79  	t.Run("should adjust if exists", func(t *testing.T) {
    80  		pool := stdmap.NewBackend()
    81  		_ = pool.Add(item1)
    82  
    83  		updatedItem, ok := pool.Adjust(item1.ID(), func(old flow.Entity) flow.Entity {
    84  			// item 1 exist, got replaced with item2, the value was updated
    85  			return item2
    86  		})
    87  
    88  		assert.True(t, ok)
    89  		assert.Equal(t, updatedItem, item2)
    90  
    91  		value2, found := pool.ByID(item2.ID())
    92  		assert.True(t, found)
    93  		assert.Equal(t, value2, item2)
    94  	})
    95  }
    96  
    97  // Test that size mempool de-duplicates based on ID
    98  func Test_DeduplicationByID(t *testing.T) {
    99  	item1 := unittest.MockEntityFixture()
   100  	item2 := unittest.MockEntity{Identifier: item1.Identifier} // duplicate
   101  	assert.True(t, item1.ID() == item2.ID())
   102  
   103  	pool := stdmap.NewBackend()
   104  	pool.Add(item1)
   105  	pool.Add(item2)
   106  	assert.Equal(t, uint(1), pool.Size())
   107  }
   108  
   109  // TestBackend_RunLimitChecking defines a backend with size limit of `limit`. It then
   110  // starts adding `swarm`-many items concurrently to the backend each on a separate goroutine,
   111  // where `swarm` > `limit`,
   112  // and evaluates that size of the map stays within the limit.
   113  func TestBackend_RunLimitChecking(t *testing.T) {
   114  	const (
   115  		limit = 150
   116  		swarm = 150
   117  	)
   118  	pool := stdmap.NewBackend(stdmap.WithLimit(limit))
   119  
   120  	wg := sync.WaitGroup{}
   121  	wg.Add(swarm)
   122  
   123  	for i := 0; i < swarm; i++ {
   124  		go func(x int) {
   125  			// creates and adds a fake item to the mempool
   126  			item := unittest.MockEntityFixture()
   127  			_ = pool.Run(func(backdata mempool.BackData) error {
   128  				added := backdata.Add(item.ID(), item)
   129  				if !added {
   130  					return fmt.Errorf("potential race condition on adding to back data")
   131  				}
   132  
   133  				return nil
   134  			})
   135  
   136  			// evaluates that the size remains in the permissible range
   137  			require.True(t, pool.Size() <= uint(limit),
   138  				fmt.Sprintf("size violation: should be at most: %d, got: %d", limit, pool.Size()))
   139  			wg.Done()
   140  		}(i)
   141  	}
   142  
   143  	unittest.RequireReturnsBefore(t, wg.Wait, 1*time.Second, "test could not finish on time")
   144  }
   145  
   146  // TestBackend_RegisterEjectionCallback verifies that the Backend calls the
   147  // ejection callbacks whenever it ejects a stored entity due to size limitations.
   148  func TestBackend_RegisterEjectionCallback(t *testing.T) {
   149  	const (
   150  		limit = 20
   151  		swarm = 20
   152  	)
   153  	pool := stdmap.NewBackend(stdmap.WithLimit(limit))
   154  
   155  	// on ejection callback: test whether ejected identity is no longer part of the mempool
   156  	ensureEntityNotInMempool := func(entity flow.Entity) {
   157  		id := entity.ID()
   158  		go func() {
   159  			e, found := pool.ByID(id)
   160  			require.False(t, found)
   161  			require.Nil(t, e)
   162  		}()
   163  		go func() {
   164  			require.False(t, pool.Has(id))
   165  		}()
   166  	}
   167  	pool.RegisterEjectionCallbacks(ensureEntityNotInMempool)
   168  
   169  	wg := sync.WaitGroup{}
   170  	wg.Add(swarm)
   171  	for i := 0; i < swarm; i++ {
   172  		go func(x int) {
   173  			// creates and adds a fake item to the mempool
   174  			item := unittest.MockEntityFixture()
   175  			pool.Add(item)
   176  			wg.Done()
   177  		}(i)
   178  	}
   179  
   180  	unittest.RequireReturnsBefore(t, wg.Wait, 1*time.Second, "test could not finish on time")
   181  	require.Equal(t, uint(limit), pool.Size(), "expected mempool to be at max capacity limit")
   182  }
   183  
   184  // TestBackend_Multiple_OnEjectionCallbacks verifies that the Backend
   185  // handles multiple ejection callbacks correctly
   186  func TestBackend_Multiple_OnEjectionCallbacks(t *testing.T) {
   187  	// ejection callback counts number of calls
   188  	calls := uint64(0)
   189  	callback := func(entity flow.Entity) {
   190  		atomic.AddUint64(&calls, 1)
   191  	}
   192  
   193  	// construct backend
   194  	const (
   195  		limit = 30
   196  	)
   197  	pool := stdmap.NewBackend(stdmap.WithLimit(limit))
   198  	pool.RegisterEjectionCallbacks(callback, callback)
   199  
   200  	t.Run("fill mempool up to limit", func(t *testing.T) {
   201  		addRandomEntities(t, pool, limit)
   202  		require.Equal(t, uint(limit), pool.Size(), "expected mempool to be at max capacity limit")
   203  		require.Equal(t, uint64(0), atomic.LoadUint64(&calls))
   204  	})
   205  
   206  	t.Run("add elements beyond limit", func(t *testing.T) {
   207  		addRandomEntities(t, pool, 2) // as we registered callback _twice_, we should receive 2 calls per ejection
   208  		require.Less(t, uint(limit), pool.Size(), "expected mempool to be at max capacity limit")
   209  		require.Equal(t, uint64(0), atomic.LoadUint64(&calls))
   210  	})
   211  
   212  	t.Run("fill mempool up to limit", func(t *testing.T) {
   213  		atomic.StoreUint64(&calls, uint64(0))
   214  		pool.RegisterEjectionCallbacks(callback) // now we have registered the callback three times
   215  		addRandomEntities(t, pool, 7)            // => we should receive 3 calls per ejection
   216  		require.Less(t, uint(limit), pool.Size(), "expected mempool to be at max capacity limit")
   217  		require.Equal(t, uint64(0), atomic.LoadUint64(&calls))
   218  	})
   219  }
   220  
   221  // TestBackend_AdjustWithInit_Concurrent tests the AdjustWithInit method of the Backend with HeroCache as the backdata.
   222  // It concurrently attempts on adjusting non-existent entities, and verifies that the entities are initialized and adjusted correctly.
   223  func TestBackend_AdjustWithInit_Concurrent_HeroCache(t *testing.T) {
   224  	sizeLimit := uint32(100)
   225  	backData := herocache.NewCache(sizeLimit,
   226  		herocache.DefaultOversizeFactor,
   227  		heropool.LRUEjection,
   228  		unittest.Logger(),
   229  		metrics.NewNoopCollector())
   230  
   231  	backend := stdmap.NewBackend(stdmap.WithBackData(backData))
   232  	entities := unittest.EntityListFixture(100)
   233  	adjustDone := sync.WaitGroup{}
   234  	for _, e := range entities {
   235  		adjustDone.Add(1)
   236  		e := e // capture range variable
   237  		go func() {
   238  			defer adjustDone.Done()
   239  
   240  			backend.AdjustWithInit(e.ID(), func(entity flow.Entity) flow.Entity {
   241  				// increment nonce of the entity
   242  				mockEntity, ok := entity.(*unittest.MockEntity)
   243  				require.True(t, ok)
   244  				mockEntity.Nonce++
   245  				return entity
   246  			}, func() flow.Entity {
   247  				return e
   248  			})
   249  		}()
   250  	}
   251  
   252  	unittest.RequireReturnsBefore(t, adjustDone.Wait, 1*time.Second, "failed to adjust elements in time")
   253  
   254  	for _, e := range entities {
   255  		actual, ok := backend.ByID(e.ID())
   256  		require.True(t, ok)
   257  		require.Equal(t, e.ID(), actual.ID())
   258  		require.Equal(t, uint64(1), actual.(*unittest.MockEntity).Nonce)
   259  	}
   260  }
   261  
   262  // TestBackend_GetWithInit_Concurrent tests the GetWithInit method of the Backend with HeroCache as the backdata.
   263  // It concurrently attempts on adjusting non-existent entities, and verifies that the entities are initialized and retrieved correctly.
   264  func TestBackend_GetWithInit_Concurrent_HeroCache(t *testing.T) {
   265  	sizeLimit := uint32(100)
   266  	backData := herocache.NewCache(sizeLimit, herocache.DefaultOversizeFactor, heropool.LRUEjection, unittest.Logger(), metrics.NewNoopCollector())
   267  
   268  	backend := stdmap.NewBackend(stdmap.WithBackData(backData))
   269  	entities := unittest.EntityListFixture(100)
   270  	adjustDone := sync.WaitGroup{}
   271  	for _, e := range entities {
   272  		adjustDone.Add(1)
   273  		e := e // capture range variable
   274  		go func() {
   275  			defer adjustDone.Done()
   276  
   277  			entity, ok := backend.GetWithInit(e.ID(), func() flow.Entity {
   278  				return e
   279  			})
   280  			require.True(t, ok)
   281  			require.Equal(t, e.ID(), entity.ID())
   282  		}()
   283  	}
   284  
   285  	unittest.RequireReturnsBefore(t, adjustDone.Wait, 1*time.Second, "failed to get-with-init elements in time")
   286  
   287  	for _, e := range entities {
   288  		actual, ok := backend.ByID(e.ID())
   289  		require.True(t, ok)
   290  		require.Equal(t, e.ID(), actual.ID())
   291  	}
   292  }
   293  
   294  // TestBackend_AdjustWithInit_Concurrent_MapBased tests the AdjustWithInit method of the Backend with golang map as the backdata.
   295  // It concurrently attempts on adjusting non-existent entities, and verifies that the entities are initialized and adjusted correctly.
   296  func TestBackend_AdjustWithInit_Concurrent_MapBased(t *testing.T) {
   297  	sizeLimit := uint(100)
   298  	backend := stdmap.NewBackend(stdmap.WithLimit(sizeLimit))
   299  	entities := unittest.EntityListFixture(sizeLimit)
   300  
   301  	adjustDone := sync.WaitGroup{}
   302  	for _, e := range entities {
   303  		adjustDone.Add(1)
   304  		e := e // capture range variable
   305  		go func() {
   306  			defer adjustDone.Done()
   307  
   308  			backend.AdjustWithInit(e.ID(), func(entity flow.Entity) flow.Entity {
   309  				// increment nonce of the entity
   310  				mockEntity, ok := entity.(*unittest.MockEntity)
   311  				require.True(t, ok)
   312  				mockEntity.Nonce++
   313  				return entity
   314  			}, func() flow.Entity {
   315  				return e
   316  			})
   317  		}()
   318  	}
   319  
   320  	unittest.RequireReturnsBefore(t, adjustDone.Wait, 1*time.Second, "failed to adjust elements in time")
   321  
   322  	for _, e := range entities {
   323  		actual, ok := backend.ByID(e.ID())
   324  		require.True(t, ok)
   325  		require.Equal(t, e.ID(), actual.ID())
   326  		require.Equal(t, uint64(1), actual.(*unittest.MockEntity).Nonce)
   327  	}
   328  }
   329  
   330  // TestBackend_GetWithInit_Concurrentt_MapBased tests the GetWithInit method of the Backend with golang map as the backdata.
   331  // It concurrently attempts on adjusting non-existent entities, and verifies that the entities are initialized and retrieved correctly.
   332  func TestBackend_GetWithInit_Concurrent_MapBased(t *testing.T) {
   333  	sizeLimit := uint(100)
   334  	backend := stdmap.NewBackend(stdmap.WithLimit(sizeLimit))
   335  	entities := unittest.EntityListFixture(100)
   336  	adjustDone := sync.WaitGroup{}
   337  	for _, e := range entities {
   338  		adjustDone.Add(1)
   339  		e := e // capture range variable
   340  		go func() {
   341  			defer adjustDone.Done()
   342  
   343  			entity, ok := backend.GetWithInit(e.ID(), func() flow.Entity {
   344  				return e
   345  			})
   346  			require.True(t, ok)
   347  			require.Equal(t, e.ID(), entity.ID())
   348  		}()
   349  	}
   350  
   351  	unittest.RequireReturnsBefore(t, adjustDone.Wait, 1*time.Second, "failed to get-with-init elements in time")
   352  
   353  	for _, e := range entities {
   354  		actual, ok := backend.ByID(e.ID())
   355  		require.True(t, ok)
   356  		require.Equal(t, e.ID(), actual.ID())
   357  	}
   358  }
   359  
   360  func addRandomEntities(t *testing.T, backend *stdmap.Backend, num int) {
   361  	// add swarm-number of items to backend
   362  	wg := sync.WaitGroup{}
   363  	wg.Add(num)
   364  	for ; num > 0; num-- {
   365  		go func() {
   366  			defer wg.Done()
   367  			backend.Add(unittest.MockEntityFixture()) // creates and adds a fake item to the mempool
   368  		}()
   369  	}
   370  	unittest.RequireReturnsBefore(t, wg.Wait, 1*time.Second, "failed to add elements in time")
   371  }
   372  
   373  func TestBackend_All(t *testing.T) {
   374  	backend := stdmap.NewBackend()
   375  	entities := unittest.EntityListFixture(100)
   376  
   377  	// Add
   378  	for _, e := range entities {
   379  		// all entities must be stored successfully
   380  		require.True(t, backend.Add(e))
   381  	}
   382  
   383  	// All
   384  	all := backend.All()
   385  	require.Equal(t, len(entities), len(all))
   386  	for _, expected := range entities {
   387  		actual, ok := backend.ByID(expected.ID())
   388  		require.True(t, ok)
   389  		require.Equal(t, expected, actual)
   390  	}
   391  }