github.com/onflow/flow-go@v0.33.17/module/mempool/stdmap/backend_test.go (about)

     1  // (c) 2019 Dapper Labs - ALL RIGHTS RESERVED
     2  
     3  package stdmap_test
     4  
     5  import (
     6  	"fmt"
     7  	"sync"
     8  	"sync/atomic"
     9  	"testing"
    10  	"time"
    11  
    12  	"github.com/stretchr/testify/assert"
    13  	"github.com/stretchr/testify/require"
    14  
    15  	"github.com/onflow/flow-go/model/flow"
    16  	"github.com/onflow/flow-go/module/mempool"
    17  	herocache "github.com/onflow/flow-go/module/mempool/herocache/backdata"
    18  	"github.com/onflow/flow-go/module/mempool/herocache/backdata/heropool"
    19  	"github.com/onflow/flow-go/module/mempool/stdmap"
    20  	"github.com/onflow/flow-go/module/metrics"
    21  	"github.com/onflow/flow-go/utils/unittest"
    22  )
    23  
    24  func TestAddRemove(t *testing.T) {
    25  	item1 := unittest.MockEntityFixture()
    26  	item2 := unittest.MockEntityFixture()
    27  
    28  	t.Run("should be able to add and rem", func(t *testing.T) {
    29  		pool := stdmap.NewBackend()
    30  		added := pool.Add(item1)
    31  		require.True(t, added)
    32  		added = pool.Add(item2)
    33  		require.True(t, added)
    34  
    35  		t.Run("should be able to get size", func(t *testing.T) {
    36  			size := pool.Size()
    37  			assert.EqualValues(t, uint(2), size)
    38  		})
    39  
    40  		t.Run("should be able to get first", func(t *testing.T) {
    41  			gotItem, exists := pool.ByID(item1.ID())
    42  			assert.True(t, exists)
    43  			assert.Equal(t, item1, gotItem)
    44  		})
    45  
    46  		t.Run("should be able to remove first", func(t *testing.T) {
    47  			removed := pool.Remove(item1.ID())
    48  			assert.True(t, removed)
    49  			size := pool.Size()
    50  			assert.EqualValues(t, uint(1), size)
    51  		})
    52  
    53  		t.Run("should be able to retrieve all", func(t *testing.T) {
    54  			items := pool.All()
    55  			require.Len(t, items, 1)
    56  			assert.Equal(t, item2, items[0])
    57  		})
    58  	})
    59  }
    60  
    61  func TestAdjust(t *testing.T) {
    62  	item1 := unittest.MockEntityFixture()
    63  	item2 := unittest.MockEntityFixture()
    64  
    65  	t.Run("should not adjust if not exist", func(t *testing.T) {
    66  		pool := stdmap.NewBackend()
    67  		_ = pool.Add(item1)
    68  
    69  		// item2 doesn't exist
    70  		updatedItem, updated := pool.Adjust(item2.ID(), func(old flow.Entity) flow.Entity {
    71  			return item2
    72  		})
    73  
    74  		assert.False(t, updated)
    75  		assert.Nil(t, updatedItem)
    76  
    77  		_, found := pool.ByID(item2.ID())
    78  		assert.False(t, found)
    79  	})
    80  
    81  	t.Run("should adjust if exists", func(t *testing.T) {
    82  		pool := stdmap.NewBackend()
    83  		_ = pool.Add(item1)
    84  
    85  		updatedItem, ok := pool.Adjust(item1.ID(), func(old flow.Entity) flow.Entity {
    86  			// item 1 exist, got replaced with item2, the value was updated
    87  			return item2
    88  		})
    89  
    90  		assert.True(t, ok)
    91  		assert.Equal(t, updatedItem, item2)
    92  
    93  		value2, found := pool.ByID(item2.ID())
    94  		assert.True(t, found)
    95  		assert.Equal(t, value2, item2)
    96  	})
    97  }
    98  
    99  // Test that size mempool de-duplicates based on ID
   100  func Test_DeduplicationByID(t *testing.T) {
   101  	item1 := unittest.MockEntityFixture()
   102  	item2 := unittest.MockEntity{Identifier: item1.Identifier} // duplicate
   103  	assert.True(t, item1.ID() == item2.ID())
   104  
   105  	pool := stdmap.NewBackend()
   106  	pool.Add(item1)
   107  	pool.Add(item2)
   108  	assert.Equal(t, uint(1), pool.Size())
   109  }
   110  
   111  // TestBackend_RunLimitChecking defines a backend with size limit of `limit`. It then
   112  // starts adding `swarm`-many items concurrently to the backend each on a separate goroutine,
   113  // where `swarm` > `limit`,
   114  // and evaluates that size of the map stays within the limit.
   115  func TestBackend_RunLimitChecking(t *testing.T) {
   116  	const (
   117  		limit = 150
   118  		swarm = 150
   119  	)
   120  	pool := stdmap.NewBackend(stdmap.WithLimit(limit))
   121  
   122  	wg := sync.WaitGroup{}
   123  	wg.Add(swarm)
   124  
   125  	for i := 0; i < swarm; i++ {
   126  		go func(x int) {
   127  			defer wg.Done()
   128  
   129  			// creates and adds a fake item to the mempool
   130  			item := unittest.MockEntityFixture()
   131  			_ = pool.Run(func(backdata mempool.BackData) error {
   132  				added := backdata.Add(item.ID(), item)
   133  				if !added {
   134  					return fmt.Errorf("potential race condition on adding to back data")
   135  				}
   136  
   137  				return nil
   138  			})
   139  
   140  			// evaluates that the size remains in the permissible range
   141  			require.True(t, pool.Size() <= uint(limit),
   142  				fmt.Sprintf("size violation: should be at most: %d, got: %d", limit, pool.Size()))
   143  		}(i)
   144  	}
   145  
   146  	unittest.RequireReturnsBefore(t, wg.Wait, 1*time.Second, "test could not finish on time")
   147  }
   148  
   149  // TestBackend_RegisterEjectionCallback verifies that the Backend calls the
   150  // ejection callbacks whenever it ejects a stored entity due to size limitations.
   151  func TestBackend_RegisterEjectionCallback(t *testing.T) {
   152  	const (
   153  		limit = 20
   154  		swarm = 20
   155  	)
   156  	pool := stdmap.NewBackend(stdmap.WithLimit(limit))
   157  
   158  	// on ejection callback: test whether ejected identity is no longer part of the mempool
   159  	ensureEntityNotInMempool := func(entity flow.Entity) {
   160  		id := entity.ID()
   161  		go func() {
   162  			e, found := pool.ByID(id)
   163  			require.False(t, found)
   164  			require.Nil(t, e)
   165  		}()
   166  		go func() {
   167  			require.False(t, pool.Has(id))
   168  		}()
   169  	}
   170  	pool.RegisterEjectionCallbacks(ensureEntityNotInMempool)
   171  
   172  	wg := sync.WaitGroup{}
   173  	wg.Add(swarm)
   174  	for i := 0; i < swarm; i++ {
   175  		go func(x int) {
   176  			defer wg.Done()
   177  			// creates and adds a fake item to the mempool
   178  			item := unittest.MockEntityFixture()
   179  			pool.Add(item)
   180  		}(i)
   181  	}
   182  
   183  	unittest.RequireReturnsBefore(t, wg.Wait, 1*time.Second, "test could not finish on time")
   184  	require.Equal(t, uint(limit), pool.Size(), "expected mempool to be at max capacity limit")
   185  }
   186  
   187  // TestBackend_Multiple_OnEjectionCallbacks verifies that the Backend
   188  // handles multiple ejection callbacks correctly
   189  func TestBackend_Multiple_OnEjectionCallbacks(t *testing.T) {
   190  	// ejection callback counts number of calls
   191  	calls := uint64(0)
   192  	callback := func(entity flow.Entity) {
   193  		atomic.AddUint64(&calls, 1)
   194  	}
   195  
   196  	// construct backend
   197  	const (
   198  		limit = 30
   199  	)
   200  	pool := stdmap.NewBackend(stdmap.WithLimit(limit))
   201  	pool.RegisterEjectionCallbacks(callback, callback)
   202  
   203  	t.Run("fill mempool up to limit", func(t *testing.T) {
   204  		addRandomEntities(t, pool, limit)
   205  		require.Equal(t, uint(limit), pool.Size(), "expected mempool to be at max capacity limit")
   206  		require.Equal(t, uint64(0), atomic.LoadUint64(&calls))
   207  	})
   208  
   209  	t.Run("add elements beyond limit", func(t *testing.T) {
   210  		addRandomEntities(t, pool, 2) // as we registered callback _twice_, we should receive 2 calls per ejection
   211  		require.Less(t, uint(limit), pool.Size(), "expected mempool to be at max capacity limit")
   212  		require.Equal(t, uint64(0), atomic.LoadUint64(&calls))
   213  	})
   214  
   215  	t.Run("fill mempool up to limit", func(t *testing.T) {
   216  		atomic.StoreUint64(&calls, uint64(0))
   217  		pool.RegisterEjectionCallbacks(callback) // now we have registered the callback three times
   218  		addRandomEntities(t, pool, 7)            // => we should receive 3 calls per ejection
   219  		require.Less(t, uint(limit), pool.Size(), "expected mempool to be at max capacity limit")
   220  		require.Equal(t, uint64(0), atomic.LoadUint64(&calls))
   221  	})
   222  }
   223  
   224  // TestBackend_AdjustWithInit_Concurrent tests the AdjustWithInit method of the Backend with HeroCache as the backdata.
   225  // It concurrently attempts on adjusting non-existent entities, and verifies that the entities are initialized and adjusted correctly.
   226  func TestBackend_AdjustWithInit_Concurrent_HeroCache(t *testing.T) {
   227  	sizeLimit := uint32(100)
   228  	backData := herocache.NewCache(sizeLimit,
   229  		herocache.DefaultOversizeFactor,
   230  		heropool.LRUEjection,
   231  		unittest.Logger(),
   232  		metrics.NewNoopCollector())
   233  
   234  	backend := stdmap.NewBackend(stdmap.WithBackData(backData))
   235  	entities := unittest.EntityListFixture(100)
   236  	adjustDone := sync.WaitGroup{}
   237  	for _, e := range entities {
   238  		adjustDone.Add(1)
   239  		e := e // capture range variable
   240  		go func() {
   241  			defer adjustDone.Done()
   242  
   243  			backend.AdjustWithInit(e.ID(), func(entity flow.Entity) flow.Entity {
   244  				// increment nonce of the entity
   245  				mockEntity, ok := entity.(*unittest.MockEntity)
   246  				require.True(t, ok)
   247  				mockEntity.Nonce++
   248  				return entity
   249  			}, func() flow.Entity {
   250  				return e
   251  			})
   252  		}()
   253  	}
   254  
   255  	unittest.RequireReturnsBefore(t, adjustDone.Wait, 1*time.Second, "failed to adjust elements in time")
   256  
   257  	for _, e := range entities {
   258  		actual, ok := backend.ByID(e.ID())
   259  		require.True(t, ok)
   260  		require.Equal(t, e.ID(), actual.ID())
   261  		require.Equal(t, uint64(1), actual.(*unittest.MockEntity).Nonce)
   262  	}
   263  }
   264  
   265  // TestBackend_GetWithInit_Concurrent tests the GetWithInit method of the Backend with HeroCache as the backdata.
   266  // It concurrently attempts on adjusting non-existent entities, and verifies that the entities are initialized and retrieved correctly.
   267  func TestBackend_GetWithInit_Concurrent_HeroCache(t *testing.T) {
   268  	sizeLimit := uint32(100)
   269  	backData := herocache.NewCache(sizeLimit, herocache.DefaultOversizeFactor, heropool.LRUEjection, unittest.Logger(), metrics.NewNoopCollector())
   270  
   271  	backend := stdmap.NewBackend(stdmap.WithBackData(backData))
   272  	entities := unittest.EntityListFixture(100)
   273  	adjustDone := sync.WaitGroup{}
   274  	for _, e := range entities {
   275  		adjustDone.Add(1)
   276  		e := e // capture range variable
   277  		go func() {
   278  			defer adjustDone.Done()
   279  
   280  			entity, ok := backend.GetWithInit(e.ID(), func() flow.Entity {
   281  				return e
   282  			})
   283  			require.True(t, ok)
   284  			require.Equal(t, e.ID(), entity.ID())
   285  		}()
   286  	}
   287  
   288  	unittest.RequireReturnsBefore(t, adjustDone.Wait, 1*time.Second, "failed to get-with-init elements in time")
   289  
   290  	for _, e := range entities {
   291  		actual, ok := backend.ByID(e.ID())
   292  		require.True(t, ok)
   293  		require.Equal(t, e.ID(), actual.ID())
   294  	}
   295  }
   296  
   297  // TestBackend_AdjustWithInit_Concurrent_MapBased tests the AdjustWithInit method of the Backend with golang map as the backdata.
   298  // It concurrently attempts on adjusting non-existent entities, and verifies that the entities are initialized and adjusted correctly.
   299  func TestBackend_AdjustWithInit_Concurrent_MapBased(t *testing.T) {
   300  	unittest.SkipUnless(t, unittest.TEST_FLAKY, "flakey on CI, fix is in progress")
   301  	sizeLimit := uint(100)
   302  	backend := stdmap.NewBackend(stdmap.WithLimit(sizeLimit))
   303  	entities := unittest.EntityListFixture(sizeLimit)
   304  	adjustDone := sync.WaitGroup{}
   305  	for _, e := range entities {
   306  		adjustDone.Add(1)
   307  		e := e // capture range variable
   308  		go func() {
   309  			defer adjustDone.Done()
   310  
   311  			backend.AdjustWithInit(e.ID(), func(entity flow.Entity) flow.Entity {
   312  				// increment nonce of the entity
   313  				mockEntity, ok := entity.(*unittest.MockEntity)
   314  				require.True(t, ok)
   315  				mockEntity.Nonce++
   316  				return entity
   317  			}, func() flow.Entity {
   318  				return e
   319  			})
   320  		}()
   321  	}
   322  
   323  	unittest.RequireReturnsBefore(t, adjustDone.Wait, 1*time.Second, "failed to adjust elements in time")
   324  
   325  	for _, e := range entities {
   326  		actual, ok := backend.ByID(e.ID())
   327  		require.True(t, ok)
   328  		require.Equal(t, e.ID(), actual.ID())
   329  		require.Equal(t, uint64(1), actual.(*unittest.MockEntity).Nonce)
   330  	}
   331  }
   332  
   333  // TestBackend_GetWithInit_Concurrentt_MapBased tests the GetWithInit method of the Backend with golang map as the backdata.
   334  // It concurrently attempts on adjusting non-existent entities, and verifies that the entities are initialized and retrieved correctly.
   335  func TestBackend_GetWithInit_Concurrent_MapBased(t *testing.T) {
   336  	sizeLimit := uint(100)
   337  	backend := stdmap.NewBackend(stdmap.WithLimit(sizeLimit))
   338  	entities := unittest.EntityListFixture(100)
   339  	adjustDone := sync.WaitGroup{}
   340  	for _, e := range entities {
   341  		adjustDone.Add(1)
   342  		e := e // capture range variable
   343  		go func() {
   344  			defer adjustDone.Done()
   345  
   346  			entity, ok := backend.GetWithInit(e.ID(), func() flow.Entity {
   347  				return e
   348  			})
   349  			require.True(t, ok)
   350  			require.Equal(t, e.ID(), entity.ID())
   351  		}()
   352  	}
   353  
   354  	unittest.RequireReturnsBefore(t, adjustDone.Wait, 1*time.Second, "failed to get-with-init elements in time")
   355  
   356  	for _, e := range entities {
   357  		actual, ok := backend.ByID(e.ID())
   358  		require.True(t, ok)
   359  		require.Equal(t, e.ID(), actual.ID())
   360  	}
   361  }
   362  
   363  func addRandomEntities(t *testing.T, backend *stdmap.Backend, num int) {
   364  	// add swarm-number of items to backend
   365  	wg := sync.WaitGroup{}
   366  	wg.Add(num)
   367  	for ; num > 0; num-- {
   368  		go func() {
   369  			defer wg.Done()
   370  			backend.Add(unittest.MockEntityFixture()) // creates and adds a fake item to the mempool
   371  		}()
   372  	}
   373  	unittest.RequireReturnsBefore(t, wg.Wait, 1*time.Second, "failed to add elements in time")
   374  }
   375  
   376  func TestBackend_All(t *testing.T) {
   377  	backend := stdmap.NewBackend()
   378  	entities := unittest.EntityListFixture(100)
   379  
   380  	// Add
   381  	for _, e := range entities {
   382  		// all entities must be stored successfully
   383  		require.True(t, backend.Add(e))
   384  	}
   385  
   386  	// All
   387  	all := backend.All()
   388  	require.Equal(t, len(entities), len(all))
   389  	for _, expected := range entities {
   390  		actual, ok := backend.ByID(expected.ID())
   391  		require.True(t, ok)
   392  		require.Equal(t, expected, actual)
   393  	}
   394  }