github.com/onflow/flow-go@v0.35.7-crescendo-preview.23-atree-inlining/network/alsp/internal/cache_test.go (about)

     1  package internal_test
     2  
     3  import (
     4  	"errors"
     5  	"sync"
     6  	"testing"
     7  	"time"
     8  
     9  	"github.com/rs/zerolog"
    10  	"github.com/stretchr/testify/require"
    11  
    12  	"github.com/onflow/flow-go/model/flow"
    13  	"github.com/onflow/flow-go/module/metrics"
    14  	"github.com/onflow/flow-go/network/alsp/internal"
    15  	"github.com/onflow/flow-go/network/alsp/model"
    16  	"github.com/onflow/flow-go/utils/unittest"
    17  )
    18  
    19  // TestNewSpamRecordCache tests the creation of a new SpamRecordCache.
    20  // It ensures that the returned cache is not nil. It does not test the
    21  // functionality of the cache.
    22  func TestNewSpamRecordCache(t *testing.T) {
    23  	sizeLimit := uint32(100)
    24  	logger := zerolog.Nop()
    25  	collector := metrics.NewNoopCollector()
    26  	recordFactory := func(id flow.Identifier) model.ProtocolSpamRecord {
    27  		return protocolSpamRecordFixture(id)
    28  	}
    29  
    30  	cache := internal.NewSpamRecordCache(sizeLimit, logger, collector, recordFactory)
    31  	require.NotNil(t, cache)
    32  	require.Equalf(t, uint(0), cache.Size(), "cache size must be 0")
    33  }
    34  
    35  // protocolSpamRecordFixture creates a new protocol spam record with the given origin id.
    36  // Args:
    37  // - id: the origin id of the spam record.
    38  // Returns:
    39  // - alsp.ProtocolSpamRecord, the created spam record.
    40  // Note that the returned spam record is not a valid spam record. It is used only for testing.
    41  func protocolSpamRecordFixture(id flow.Identifier) model.ProtocolSpamRecord {
    42  	return model.ProtocolSpamRecord{
    43  		OriginId:      id,
    44  		Decay:         1000,
    45  		CutoffCounter: 0,
    46  		Penalty:       0,
    47  	}
    48  }
    49  
    50  // TestSpamRecordCache_Adjust_Init tests that when the Adjust function is called
    51  // on a record that does not exist in the cache, the record is initialized and
    52  // the adjust function is applied to the initialized record.
    53  func TestSpamRecordCache_Adjust_Init(t *testing.T) {
    54  	sizeLimit := uint32(100)
    55  	logger := zerolog.Nop()
    56  	collector := metrics.NewNoopCollector()
    57  
    58  	recordFactoryCalled := 0
    59  	recordFactory := func(id flow.Identifier) model.ProtocolSpamRecord {
    60  		require.Less(t, recordFactoryCalled, 2, "record factory must be called only twice")
    61  		return protocolSpamRecordFixture(id)
    62  	}
    63  	adjustFuncIncrement := func(record model.ProtocolSpamRecord) (model.ProtocolSpamRecord, error) {
    64  		record.Penalty += 1
    65  		return record, nil
    66  	}
    67  
    68  	cache := internal.NewSpamRecordCache(sizeLimit, logger, collector, recordFactory)
    69  	require.NotNil(t, cache)
    70  	require.Zerof(t, cache.Size(), "expected cache to be empty")
    71  
    72  	originID1 := unittest.IdentifierFixture()
    73  	originID2 := unittest.IdentifierFixture()
    74  
    75  	// adjusting a spam record for an origin ID that does not exist in the cache should initialize the record.
    76  	initializedPenalty, err := cache.AdjustWithInit(originID1, adjustFuncIncrement)
    77  	require.NoError(t, err, "expected no error")
    78  	require.Equal(t, float64(1), initializedPenalty, "expected initialized penalty to be 1")
    79  
    80  	record1, ok := cache.Get(originID1)
    81  	require.True(t, ok, "expected record to exist")
    82  	require.NotNil(t, record1, "expected non-nil record")
    83  	require.Equal(t, originID1, record1.OriginId, "expected record to have correct origin ID")
    84  	require.False(t, record1.DisallowListed, "expected record to not be disallow listed")
    85  	require.Equal(t, cache.Size(), uint(1), "expected cache to have one record")
    86  
    87  	// adjusting a spam record for an origin ID that already exists in the cache should not initialize the record,
    88  	// but should apply the adjust function to the existing record.
    89  	initializedPenalty, err = cache.AdjustWithInit(originID1, adjustFuncIncrement)
    90  	require.NoError(t, err, "expected no error")
    91  	require.Equal(t, float64(2), initializedPenalty, "expected initialized penalty to be 2")
    92  	record1Again, ok := cache.Get(originID1)
    93  	require.True(t, ok, "expected record to still exist")
    94  	require.NotNil(t, record1Again, "expected non-nil record")
    95  	require.Equal(t, originID1, record1Again.OriginId, "expected record to have correct origin ID")
    96  	require.False(t, record1Again.DisallowListed, "expected record not to be disallow listed")
    97  	require.Equal(t, cache.Size(), uint(1), "expected cache to still have one record")
    98  
    99  	// adjusting a spam record for a different origin ID should initialize the record.
   100  	// this is to ensure that the record factory is called only once.
   101  	initializedPenalty, err = cache.AdjustWithInit(originID2, adjustFuncIncrement)
   102  	require.NoError(t, err, "expected no error")
   103  	require.Equal(t, float64(1), initializedPenalty, "expected initialized penalty to be 1")
   104  	record2, ok := cache.Get(originID2)
   105  	require.True(t, ok, "expected record to exist")
   106  	require.NotNil(t, record2, "expected non-nil record")
   107  	require.Equal(t, originID2, record2.OriginId, "expected record to have correct origin ID")
   108  	require.False(t, record2.DisallowListed, "expected record not to be disallow listed")
   109  	require.Equal(t, cache.Size(), uint(2), "expected cache to have two records")
   110  }
   111  
   112  // TestSpamRecordCache_Adjust tests the Adjust method of the SpamRecordCache.
   113  // The test covers the following scenarios:
   114  // 1. Adjusting a spam record for an existing origin ID.
   115  // 2. Attempting to adjust a spam record with an adjustFunc that returns an error.
   116  func TestSpamRecordCache_Adjust_Error(t *testing.T) {
   117  	sizeLimit := uint32(100)
   118  	logger := zerolog.Nop()
   119  	collector := metrics.NewNoopCollector()
   120  	recordFactory := func(id flow.Identifier) model.ProtocolSpamRecord {
   121  		return protocolSpamRecordFixture(id)
   122  	}
   123  	adjustFnNoOp := func(record model.ProtocolSpamRecord) (model.ProtocolSpamRecord, error) {
   124  		return record, nil // no-op
   125  	}
   126  
   127  	cache := internal.NewSpamRecordCache(sizeLimit, logger, collector, recordFactory)
   128  	require.NotNil(t, cache)
   129  
   130  	originID1 := unittest.IdentifierFixture()
   131  	originID2 := unittest.IdentifierFixture()
   132  
   133  	// initialize spam records for originID1 and originID2
   134  	penalty, err := cache.AdjustWithInit(originID1, adjustFnNoOp)
   135  	require.NoError(t, err, "expected no error")
   136  	require.Equal(t, 0.0, penalty, "expected penalty to be 0")
   137  	penalty, err = cache.AdjustWithInit(originID2, adjustFnNoOp)
   138  	require.NoError(t, err, "expected no error")
   139  	require.Equal(t, 0.0, penalty, "expected penalty to be 0")
   140  
   141  	// test adjusting the spam record for an existing origin ID
   142  	adjustFunc := func(record model.ProtocolSpamRecord) (model.ProtocolSpamRecord, error) {
   143  		record.Penalty -= 10
   144  		return record, nil
   145  	}
   146  	penalty, err = cache.AdjustWithInit(originID1, adjustFunc)
   147  	require.NoError(t, err)
   148  	require.Equal(t, -10.0, penalty)
   149  
   150  	record1, ok := cache.Get(originID1)
   151  	require.True(t, ok)
   152  	require.NotNil(t, record1)
   153  	require.Equal(t, -10.0, record1.Penalty)
   154  
   155  	// test adjusting the spam record with an adjustFunc that returns an error
   156  	adjustFuncError := func(record model.ProtocolSpamRecord) (model.ProtocolSpamRecord, error) {
   157  		return record, errors.New("adjustment error")
   158  	}
   159  	_, err = cache.AdjustWithInit(originID1, adjustFuncError)
   160  	require.Error(t, err)
   161  
   162  	// even though the adjustFunc returned an error, the record should be intact.
   163  	record1, ok = cache.Get(originID1)
   164  	require.True(t, ok)
   165  	require.NotNil(t, record1)
   166  	require.Equal(t, -10.0, record1.Penalty)
   167  }
   168  
   169  // TestSpamRecordCache_Identities tests the Identities method of the SpamRecordCache.
   170  // The test covers the following scenarios:
   171  // 1. Initializing the cache with multiple spam records.
   172  // 2. Checking if the Identities method returns the correct set of origin IDs.
   173  func TestSpamRecordCache_Identities(t *testing.T) {
   174  	sizeLimit := uint32(100)
   175  	logger := zerolog.Nop()
   176  	collector := metrics.NewNoopCollector()
   177  	recordFactory := func(id flow.Identifier) model.ProtocolSpamRecord {
   178  		return protocolSpamRecordFixture(id)
   179  	}
   180  	adjustFnNoOp := func(record model.ProtocolSpamRecord) (model.ProtocolSpamRecord, error) {
   181  		return record, nil // no-op
   182  	}
   183  
   184  	cache := internal.NewSpamRecordCache(sizeLimit, logger, collector, recordFactory)
   185  	require.NotNil(t, cache)
   186  
   187  	originID1 := unittest.IdentifierFixture()
   188  	originID2 := unittest.IdentifierFixture()
   189  	originID3 := unittest.IdentifierFixture()
   190  
   191  	// initialize spam records for a few origin IDs
   192  	_, err := cache.AdjustWithInit(originID1, adjustFnNoOp)
   193  	require.NoError(t, err)
   194  	_, err = cache.AdjustWithInit(originID2, adjustFnNoOp)
   195  	require.NoError(t, err)
   196  	_, err = cache.AdjustWithInit(originID3, adjustFnNoOp)
   197  	require.NoError(t, err)
   198  
   199  	// check if the Identities method returns the correct set of origin IDs
   200  	identities := cache.Identities()
   201  	require.Equal(t, 3, len(identities))
   202  
   203  	identityMap := make(map[flow.Identifier]struct{})
   204  	for _, id := range identities {
   205  		identityMap[id] = struct{}{}
   206  	}
   207  
   208  	require.Contains(t, identityMap, originID1)
   209  	require.Contains(t, identityMap, originID2)
   210  	require.Contains(t, identityMap, originID3)
   211  }
   212  
   213  // TestSpamRecordCache_Remove tests the Remove method of the SpamRecordCache.
   214  // The test covers the following scenarios:
   215  // 1. Initializing the cache with multiple spam records.
   216  // 2. Removing a spam record and checking if it is removed correctly.
   217  // 3. Ensuring the other spam records are still in the cache after removal.
   218  // 4. Attempting to remove a non-existent origin ID.
   219  func TestSpamRecordCache_Remove(t *testing.T) {
   220  	sizeLimit := uint32(100)
   221  	logger := zerolog.Nop()
   222  	collector := metrics.NewNoopCollector()
   223  	recordFactory := func(id flow.Identifier) model.ProtocolSpamRecord {
   224  		return protocolSpamRecordFixture(id)
   225  	}
   226  	adjustFnNoOp := func(record model.ProtocolSpamRecord) (model.ProtocolSpamRecord, error) {
   227  		return record, nil // no-op
   228  	}
   229  
   230  	cache := internal.NewSpamRecordCache(sizeLimit, logger, collector, recordFactory)
   231  	require.NotNil(t, cache)
   232  
   233  	originID1 := unittest.IdentifierFixture()
   234  	originID2 := unittest.IdentifierFixture()
   235  	originID3 := unittest.IdentifierFixture()
   236  
   237  	// initialize spam records for a few origin IDs
   238  	_, err := cache.AdjustWithInit(originID1, adjustFnNoOp)
   239  	require.NoError(t, err)
   240  	_, err = cache.AdjustWithInit(originID2, adjustFnNoOp)
   241  	require.NoError(t, err)
   242  	_, err = cache.AdjustWithInit(originID3, adjustFnNoOp)
   243  	require.NoError(t, err)
   244  
   245  	// remove originID1 and check if the record is removed
   246  	require.True(t, cache.Remove(originID1))
   247  	_, exists := cache.Get(originID1)
   248  	require.False(t, exists)
   249  
   250  	// check if the other origin IDs are still in the cache
   251  	_, exists = cache.Get(originID2)
   252  	require.True(t, exists)
   253  	_, exists = cache.Get(originID3)
   254  	require.True(t, exists)
   255  
   256  	// attempt to remove a non-existent origin ID
   257  	originID4 := unittest.IdentifierFixture()
   258  	require.False(t, cache.Remove(originID4))
   259  }
   260  
   261  // TestSpamRecordCache_EdgeCasesAndInvalidInputs tests the edge cases and invalid inputs for SpamRecordCache methods.
   262  // The test covers the following scenarios:
   263  // 1. Initializing a spam record multiple times.
   264  // 2. Adjusting a non-existent spam record.
   265  // 3. Removing a spam record multiple times.
   266  func TestSpamRecordCache_EdgeCasesAndInvalidInputs(t *testing.T) {
   267  	sizeLimit := uint32(100)
   268  	logger := zerolog.Nop()
   269  	collector := metrics.NewNoopCollector()
   270  	recordFactory := func(id flow.Identifier) model.ProtocolSpamRecord {
   271  		return protocolSpamRecordFixture(id)
   272  	}
   273  	adjustFnNoOp := func(record model.ProtocolSpamRecord) (model.ProtocolSpamRecord, error) {
   274  		return record, nil // no-op
   275  	}
   276  
   277  	cache := internal.NewSpamRecordCache(sizeLimit, logger, collector, recordFactory)
   278  	require.NotNil(t, cache)
   279  
   280  	// 1. initializing a spam record multiple times
   281  	originID1 := unittest.IdentifierFixture()
   282  
   283  	_, err := cache.AdjustWithInit(originID1, adjustFnNoOp)
   284  	require.NoError(t, err)
   285  	_, err = cache.AdjustWithInit(originID1, adjustFnNoOp)
   286  	require.NoError(t, err)
   287  
   288  	// 2. Test adjusting a non-existent spam record
   289  	originID2 := unittest.IdentifierFixture()
   290  	initialPenalty, err := cache.AdjustWithInit(originID2, func(record model.ProtocolSpamRecord) (model.ProtocolSpamRecord, error) {
   291  		record.Penalty -= 10
   292  		return record, nil
   293  	})
   294  	require.NoError(t, err)
   295  	require.Equal(t, float64(-10), initialPenalty)
   296  
   297  	// 3. Test removing a spam record multiple times
   298  	originID3 := unittest.IdentifierFixture()
   299  	_, err = cache.AdjustWithInit(originID3, adjustFnNoOp)
   300  	require.NoError(t, err)
   301  	require.True(t, cache.Remove(originID3))
   302  	require.False(t, cache.Remove(originID3))
   303  }
   304  
   305  // TestSpamRecordCache_ConcurrentInitialization tests the concurrent initialization of spam records.
   306  // The test covers the following scenarios:
   307  // 1. Multiple goroutines initializing spam records for different origin IDs.
   308  // 2. Ensuring that all spam records are correctly initialized.
   309  func TestSpamRecordCache_ConcurrentInitialization(t *testing.T) {
   310  	sizeLimit := uint32(100)
   311  	logger := zerolog.Nop()
   312  	collector := metrics.NewNoopCollector()
   313  	recordFactory := func(id flow.Identifier) model.ProtocolSpamRecord {
   314  		return protocolSpamRecordFixture(id)
   315  	}
   316  	adjustFnNoOp := func(record model.ProtocolSpamRecord) (model.ProtocolSpamRecord, error) {
   317  		return record, nil // no-op
   318  	}
   319  
   320  	cache := internal.NewSpamRecordCache(sizeLimit, logger, collector, recordFactory)
   321  	require.NotNil(t, cache)
   322  
   323  	originIDs := unittest.IdentifierListFixture(10)
   324  
   325  	var wg sync.WaitGroup
   326  	wg.Add(len(originIDs))
   327  
   328  	for _, originID := range originIDs {
   329  		go func(id flow.Identifier) {
   330  			defer wg.Done()
   331  			penalty, err := cache.AdjustWithInit(id, adjustFnNoOp)
   332  			require.NoError(t, err)
   333  			require.Equal(t, float64(0), penalty)
   334  		}(originID)
   335  	}
   336  
   337  	unittest.RequireReturnsBefore(t, wg.Wait, 100*time.Millisecond, "timed out waiting for goroutines to finish")
   338  
   339  	// ensure that all spam records are correctly initialized
   340  	for _, originID := range originIDs {
   341  		record, found := cache.Get(originID)
   342  		require.True(t, found)
   343  		require.NotNil(t, record)
   344  		require.Equal(t, originID, record.OriginId)
   345  	}
   346  }
   347  
   348  // TestSpamRecordCache_ConcurrentSameRecordAdjust tests the concurrent adjust of the same spam record.
   349  // The test covers the following scenarios:
   350  // 1. Multiple goroutines attempting to adjust the same spam record concurrently.
   351  // 2. Only one of the adjust operations succeeds on initializing the record.
   352  // 3. The rest of the adjust operations only update the record (no initialization).
   353  func TestSpamRecordCache_ConcurrentSameRecordAdjust(t *testing.T) {
   354  	sizeLimit := uint32(100)
   355  	logger := zerolog.Nop()
   356  	collector := metrics.NewNoopCollector()
   357  	recordFactory := func(id flow.Identifier) model.ProtocolSpamRecord {
   358  		return protocolSpamRecordFixture(id)
   359  	}
   360  	adjustFn := func(record model.ProtocolSpamRecord) (model.ProtocolSpamRecord, error) {
   361  		record.Penalty -= 1.0
   362  		record.DisallowListed = true
   363  		record.Decay += 1.0
   364  		return record, nil // no-op
   365  	}
   366  
   367  	cache := internal.NewSpamRecordCache(sizeLimit, logger, collector, recordFactory)
   368  	require.NotNil(t, cache)
   369  
   370  	originID := unittest.IdentifierFixture()
   371  	const concurrentAttempts = 10
   372  
   373  	var wg sync.WaitGroup
   374  	wg.Add(concurrentAttempts)
   375  
   376  	for i := 0; i < concurrentAttempts; i++ {
   377  		go func() {
   378  			defer wg.Done()
   379  			penalty, err := cache.AdjustWithInit(originID, adjustFn)
   380  			require.NoError(t, err)
   381  			require.Less(t, penalty, 0.0) // penalty should be negative
   382  		}()
   383  	}
   384  
   385  	unittest.RequireReturnsBefore(t, wg.Wait, 100*time.Millisecond, "timed out waiting for goroutines to finish")
   386  
   387  	// ensure that the record is correctly initialized and adjusted in the cache
   388  	initDecay := model.SpamRecordFactory()(originID).Decay
   389  	record, found := cache.Get(originID)
   390  	require.True(t, found)
   391  	require.NotNil(t, record)
   392  	require.Equal(t, concurrentAttempts*-1.0, record.Penalty)
   393  	require.Equal(t, initDecay+concurrentAttempts*1.0, record.Decay)
   394  	require.True(t, record.DisallowListed)
   395  	require.Equal(t, originID, record.OriginId)
   396  }
   397  
   398  // TestSpamRecordCache_ConcurrentRemoval tests the concurrent removal of spam records for different origin IDs.
   399  // The test covers the following scenarios:
   400  // 1. Multiple goroutines removing spam records for different origin IDs concurrently.
   401  // 2. The records are correctly removed from the cache.
   402  func TestSpamRecordCache_ConcurrentRemoval(t *testing.T) {
   403  	sizeLimit := uint32(100)
   404  	logger := zerolog.Nop()
   405  	collector := metrics.NewNoopCollector()
   406  	recordFactory := func(id flow.Identifier) model.ProtocolSpamRecord {
   407  		return protocolSpamRecordFixture(id)
   408  	}
   409  	adjustFnNoOp := func(record model.ProtocolSpamRecord) (model.ProtocolSpamRecord, error) {
   410  		return record, nil // no-op
   411  	}
   412  
   413  	cache := internal.NewSpamRecordCache(sizeLimit, logger, collector, recordFactory)
   414  	require.NotNil(t, cache)
   415  
   416  	originIDs := unittest.IdentifierListFixture(10)
   417  	for _, originID := range originIDs {
   418  		penalty, err := cache.AdjustWithInit(originID, adjustFnNoOp)
   419  		require.NoError(t, err)
   420  		require.Equal(t, float64(0), penalty)
   421  	}
   422  
   423  	var wg sync.WaitGroup
   424  	wg.Add(len(originIDs))
   425  
   426  	for _, originID := range originIDs {
   427  		go func(id flow.Identifier) {
   428  			defer wg.Done()
   429  			removed := cache.Remove(id)
   430  			require.True(t, removed)
   431  		}(originID)
   432  	}
   433  
   434  	unittest.RequireReturnsBefore(t, wg.Wait, 100*time.Millisecond, "timed out waiting for goroutines to finish")
   435  
   436  	// ensure that the records are correctly removed from the cache
   437  	for _, originID := range originIDs {
   438  		_, found := cache.Get(originID)
   439  		require.False(t, found)
   440  	}
   441  
   442  	// ensure that the cache is empty
   443  	require.Equal(t, uint(0), cache.Size())
   444  }
   445  
   446  // TestSpamRecordCache_ConcurrentUpdatesAndReads tests the concurrent adjustments and reads of spam records for different
   447  // origin IDs. The test covers the following scenarios:
   448  // 1. Multiple goroutines adjusting spam records for different origin IDs concurrently.
   449  // 2. Multiple goroutines getting spam records for different origin IDs concurrently.
   450  // 3. The adjusted records are correctly updated in the cache.
   451  func TestSpamRecordCache_ConcurrentUpdatesAndReads(t *testing.T) {
   452  	sizeLimit := uint32(100)
   453  	logger := zerolog.Nop()
   454  	collector := metrics.NewNoopCollector()
   455  	recordFactory := func(id flow.Identifier) model.ProtocolSpamRecord {
   456  		return protocolSpamRecordFixture(id)
   457  	}
   458  	adjustFnNoOp := func(record model.ProtocolSpamRecord) (model.ProtocolSpamRecord, error) {
   459  		return record, nil // no-op
   460  	}
   461  
   462  	cache := internal.NewSpamRecordCache(sizeLimit, logger, collector, recordFactory)
   463  	require.NotNil(t, cache)
   464  
   465  	originIDs := unittest.IdentifierListFixture(10)
   466  	for _, originID := range originIDs {
   467  		penalty, err := cache.AdjustWithInit(originID, adjustFnNoOp)
   468  		require.NoError(t, err)
   469  		require.Equal(t, float64(0), penalty)
   470  	}
   471  
   472  	var wg sync.WaitGroup
   473  	wg.Add(len(originIDs) * 2)
   474  
   475  	adjustFunc := func(record model.ProtocolSpamRecord) (model.ProtocolSpamRecord, error) {
   476  		record.Penalty -= 1
   477  		return record, nil
   478  	}
   479  
   480  	for _, originID := range originIDs {
   481  		// adjust spam records concurrently
   482  		go func(id flow.Identifier) {
   483  			defer wg.Done()
   484  			_, err := cache.AdjustWithInit(id, adjustFunc)
   485  			require.NoError(t, err)
   486  		}(originID)
   487  
   488  		// get spam records concurrently
   489  		go func(id flow.Identifier) {
   490  			defer wg.Done()
   491  			record, found := cache.Get(id)
   492  			require.True(t, found)
   493  			require.NotNil(t, record)
   494  		}(originID)
   495  	}
   496  
   497  	unittest.RequireReturnsBefore(t, wg.Wait, 100*time.Millisecond, "timed out waiting for goroutines to finish")
   498  
   499  	// ensure that the records are correctly updated in the cache
   500  	for _, originID := range originIDs {
   501  		record, found := cache.Get(originID)
   502  		require.True(t, found)
   503  		require.Equal(t, -1.0, record.Penalty)
   504  	}
   505  }
   506  
   507  // TestSpamRecordCache_ConcurrentInitAndRemove tests the concurrent initialization and removal of spam records for different
   508  // origin IDs. The test covers the following scenarios:
   509  // 1. Multiple goroutines initializing spam records for different origin IDs concurrently.
   510  // 2. Multiple goroutines removing spam records for different origin IDs concurrently.
   511  // 3. The initialized records are correctly added to the cache.
   512  // 4. The removed records are correctly removed from the cache.
   513  func TestSpamRecordCache_ConcurrentInitAndRemove(t *testing.T) {
   514  	sizeLimit := uint32(100)
   515  	logger := zerolog.Nop()
   516  	collector := metrics.NewNoopCollector()
   517  	recordFactory := func(id flow.Identifier) model.ProtocolSpamRecord {
   518  		return protocolSpamRecordFixture(id)
   519  	}
   520  	adjustFnNoOp := func(record model.ProtocolSpamRecord) (model.ProtocolSpamRecord, error) {
   521  		return record, nil // no-op
   522  	}
   523  
   524  	cache := internal.NewSpamRecordCache(sizeLimit, logger, collector, recordFactory)
   525  	require.NotNil(t, cache)
   526  
   527  	originIDs := unittest.IdentifierListFixture(20)
   528  	originIDsToAdd := originIDs[:10]
   529  	originIDsToRemove := originIDs[10:]
   530  
   531  	for _, originID := range originIDsToRemove {
   532  		penalty, err := cache.AdjustWithInit(originID, adjustFnNoOp)
   533  		require.NoError(t, err)
   534  		require.Equal(t, float64(0), penalty)
   535  	}
   536  
   537  	var wg sync.WaitGroup
   538  	wg.Add(len(originIDs))
   539  
   540  	// initialize spam records concurrently
   541  	for _, originID := range originIDsToAdd {
   542  		originID := originID // capture range variable
   543  		go func() {
   544  			defer wg.Done()
   545  			penalty, err := cache.AdjustWithInit(originID, adjustFnNoOp)
   546  			require.NoError(t, err)
   547  			require.Equal(t, float64(0), penalty)
   548  		}()
   549  	}
   550  
   551  	// remove spam records concurrently
   552  	for _, originID := range originIDsToRemove {
   553  		go func(id flow.Identifier) {
   554  			defer wg.Done()
   555  			cache.Remove(id)
   556  		}(originID)
   557  	}
   558  
   559  	unittest.RequireReturnsBefore(t, wg.Wait, 100*time.Millisecond, "timed out waiting for goroutines to finish")
   560  
   561  	// ensure that the initialized records are correctly added to the cache
   562  	for _, originID := range originIDsToAdd {
   563  		record, found := cache.Get(originID)
   564  		require.True(t, found)
   565  		require.NotNil(t, record)
   566  	}
   567  
   568  	// ensure that the removed records are correctly removed from the cache
   569  	for _, originID := range originIDsToRemove {
   570  		_, found := cache.Get(originID)
   571  		require.False(t, found)
   572  	}
   573  }
   574  
   575  // TestSpamRecordCache_ConcurrentInitRemoveAdjust tests the concurrent initialization, removal, and adjustment of spam
   576  // records for different origin IDs. The test covers the following scenarios:
   577  // 1. Multiple goroutines initializing spam records for different origin IDs concurrently.
   578  // 2. Multiple goroutines removing spam records for different origin IDs concurrently.
   579  // 3. Multiple goroutines adjusting spam records for different origin IDs concurrently.
   580  func TestSpamRecordCache_ConcurrentInitRemoveAdjust(t *testing.T) {
   581  	sizeLimit := uint32(100)
   582  	logger := zerolog.Nop()
   583  	collector := metrics.NewNoopCollector()
   584  	recordFactory := func(id flow.Identifier) model.ProtocolSpamRecord {
   585  		return protocolSpamRecordFixture(id)
   586  	}
   587  	adjustFnNoOp := func(record model.ProtocolSpamRecord) (model.ProtocolSpamRecord, error) {
   588  		return record, nil // no-op
   589  	}
   590  
   591  	cache := internal.NewSpamRecordCache(sizeLimit, logger, collector, recordFactory)
   592  	require.NotNil(t, cache)
   593  
   594  	originIDs := unittest.IdentifierListFixture(30)
   595  	originIDsToAdd := originIDs[:10]
   596  	originIDsToRemove := originIDs[10:20]
   597  	originIDsToAdjust := originIDs[20:]
   598  
   599  	for _, originID := range originIDsToRemove {
   600  		penalty, err := cache.AdjustWithInit(originID, adjustFnNoOp)
   601  		require.NoError(t, err)
   602  		require.Equal(t, float64(0), penalty)
   603  	}
   604  
   605  	adjustFunc := func(record model.ProtocolSpamRecord) (model.ProtocolSpamRecord, error) {
   606  		record.Penalty -= 1
   607  		return record, nil
   608  	}
   609  
   610  	var wg sync.WaitGroup
   611  	wg.Add(len(originIDs))
   612  
   613  	// Initialize spam records concurrently
   614  	for _, originID := range originIDsToAdd {
   615  		originID := originID // capture range variable
   616  		go func() {
   617  			defer wg.Done()
   618  			penalty, err := cache.AdjustWithInit(originID, adjustFnNoOp)
   619  			require.NoError(t, err)
   620  			require.Equal(t, float64(0), penalty)
   621  		}()
   622  	}
   623  
   624  	// Remove spam records concurrently
   625  	for _, originID := range originIDsToRemove {
   626  		go func(id flow.Identifier) {
   627  			defer wg.Done()
   628  			cache.Remove(id)
   629  		}(originID)
   630  	}
   631  
   632  	// Adjust spam records concurrently
   633  	for _, originID := range originIDsToAdjust {
   634  		go func(id flow.Identifier) {
   635  			defer wg.Done()
   636  			_, _ = cache.AdjustWithInit(id, adjustFunc)
   637  		}(originID)
   638  	}
   639  
   640  	unittest.RequireReturnsBefore(t, wg.Wait, 100*time.Millisecond, "timed out waiting for goroutines to finish")
   641  }
   642  
   643  // TestSpamRecordCache_ConcurrentInitRemoveAndAdjust tests the concurrent initialization, removal, and adjustment of spam
   644  // records for different origin IDs. The test covers the following scenarios:
   645  // 1. Multiple goroutines initializing spam records for different origin IDs concurrently.
   646  // 2. Multiple goroutines removing spam records for different origin IDs concurrently.
   647  // 3. Multiple goroutines adjusting spam records for different origin IDs concurrently.
   648  // 4. The initialized records are correctly added to the cache.
   649  // 5. The removed records are correctly removed from the cache.
   650  // 6. The adjusted records are correctly updated in the cache.
   651  func TestSpamRecordCache_ConcurrentInitRemoveAndAdjust(t *testing.T) {
   652  	sizeLimit := uint32(100)
   653  	logger := zerolog.Nop()
   654  	collector := metrics.NewNoopCollector()
   655  	recordFactory := func(id flow.Identifier) model.ProtocolSpamRecord {
   656  		return protocolSpamRecordFixture(id)
   657  	}
   658  	adjustFnNoOp := func(record model.ProtocolSpamRecord) (model.ProtocolSpamRecord, error) {
   659  		return record, nil // no-op
   660  	}
   661  
   662  	cache := internal.NewSpamRecordCache(sizeLimit, logger, collector, recordFactory)
   663  	require.NotNil(t, cache)
   664  
   665  	originIDs := unittest.IdentifierListFixture(30)
   666  	originIDsToAdd := originIDs[:10]
   667  	originIDsToRemove := originIDs[10:20]
   668  	originIDsToAdjust := originIDs[20:]
   669  
   670  	for _, originID := range originIDsToRemove {
   671  		penalty, err := cache.AdjustWithInit(originID, adjustFnNoOp)
   672  		require.NoError(t, err)
   673  		require.Equal(t, float64(0), penalty)
   674  	}
   675  
   676  	for _, originID := range originIDsToAdjust {
   677  		penalty, err := cache.AdjustWithInit(originID, adjustFnNoOp)
   678  		require.NoError(t, err)
   679  		require.Equal(t, float64(0), penalty)
   680  	}
   681  
   682  	var wg sync.WaitGroup
   683  	wg.Add(len(originIDs))
   684  
   685  	// initialize spam records concurrently
   686  	for _, originID := range originIDsToAdd {
   687  		originID := originID
   688  		go func() {
   689  			defer wg.Done()
   690  			penalty, err := cache.AdjustWithInit(originID, adjustFnNoOp)
   691  			require.NoError(t, err)
   692  			require.Equal(t, float64(0), penalty)
   693  		}()
   694  	}
   695  
   696  	// remove spam records concurrently
   697  	for _, originID := range originIDsToRemove {
   698  		originID := originID
   699  		go func() {
   700  			defer wg.Done()
   701  			cache.Remove(originID)
   702  		}()
   703  	}
   704  
   705  	// adjust spam records concurrently
   706  	for _, originID := range originIDsToAdjust {
   707  		originID := originID
   708  		go func() {
   709  			defer wg.Done()
   710  			_, err := cache.AdjustWithInit(originID, func(record model.ProtocolSpamRecord) (model.ProtocolSpamRecord, error) {
   711  				record.Penalty -= 1
   712  				return record, nil
   713  			})
   714  			require.NoError(t, err)
   715  		}()
   716  	}
   717  
   718  	unittest.RequireReturnsBefore(t, wg.Wait, 100*time.Millisecond, "timed out waiting for goroutines to finish")
   719  
   720  	// ensure that the initialized records are correctly added to the cache
   721  	for _, originID := range originIDsToAdd {
   722  		record, found := cache.Get(originID)
   723  		require.True(t, found)
   724  		require.NotNil(t, record)
   725  	}
   726  
   727  	// ensure that the removed records are correctly removed from the cache
   728  	for _, originID := range originIDsToRemove {
   729  		_, found := cache.Get(originID)
   730  		require.False(t, found)
   731  	}
   732  
   733  	// ensure that the adjusted records are correctly updated in the cache
   734  	for _, originID := range originIDsToAdjust {
   735  		record, found := cache.Get(originID)
   736  		require.True(t, found)
   737  		require.NotNil(t, record)
   738  		require.Equal(t, -1.0, record.Penalty)
   739  	}
   740  }
   741  
   742  // TestSpamRecordCache_ConcurrentIdentitiesAndOperations tests the concurrent calls to Identities method while
   743  // other goroutines are initializing or removing spam records. The test covers the following scenarios:
   744  // 1. Multiple goroutines initializing spam records for different origin IDs concurrently.
   745  // 2. Multiple goroutines removing spam records for different origin IDs concurrently.
   746  // 3. Multiple goroutines calling Identities method concurrently.
   747  func TestSpamRecordCache_ConcurrentIdentitiesAndOperations(t *testing.T) {
   748  	sizeLimit := uint32(100)
   749  	logger := zerolog.Nop()
   750  	collector := metrics.NewNoopCollector()
   751  	recordFactory := func(id flow.Identifier) model.ProtocolSpamRecord {
   752  		return protocolSpamRecordFixture(id)
   753  	}
   754  	adjustFnNoOp := func(record model.ProtocolSpamRecord) (model.ProtocolSpamRecord, error) {
   755  		return record, nil // no-op
   756  	}
   757  
   758  	cache := internal.NewSpamRecordCache(sizeLimit, logger, collector, recordFactory)
   759  	require.NotNil(t, cache)
   760  
   761  	originIDs := unittest.IdentifierListFixture(20)
   762  	originIDsToAdd := originIDs[:10]
   763  	originIDsToRemove := originIDs[10:20]
   764  
   765  	for _, originID := range originIDsToRemove {
   766  		penalty, err := cache.AdjustWithInit(originID, adjustFnNoOp)
   767  		require.NoError(t, err)
   768  		require.Equal(t, float64(0), penalty)
   769  	}
   770  
   771  	var wg sync.WaitGroup
   772  	wg.Add(len(originIDs) + 10)
   773  
   774  	// initialize spam records concurrently
   775  	for _, originID := range originIDsToAdd {
   776  		originID := originID
   777  		go func() {
   778  			defer wg.Done()
   779  			penalty, err := cache.AdjustWithInit(originID, adjustFnNoOp)
   780  			require.NoError(t, err)
   781  			require.Equal(t, float64(0), penalty)
   782  			retrieved, ok := cache.Get(originID)
   783  			require.True(t, ok)
   784  			require.NotNil(t, retrieved)
   785  		}()
   786  	}
   787  
   788  	// remove spam records concurrently
   789  	for _, originID := range originIDsToRemove {
   790  		originID := originID
   791  		go func() {
   792  			defer wg.Done()
   793  			require.True(t, cache.Remove(originID))
   794  			retrieved, ok := cache.Get(originID)
   795  			require.False(t, ok)
   796  			require.Nil(t, retrieved)
   797  		}()
   798  	}
   799  
   800  	// call Identities method concurrently
   801  	for i := 0; i < 10; i++ {
   802  		go func() {
   803  			defer wg.Done()
   804  			ids := cache.Identities()
   805  			// the number of returned IDs should be less than or equal to the number of origin IDs
   806  			require.True(t, len(ids) <= len(originIDs))
   807  			// the returned IDs should be a subset of the origin IDs
   808  			for _, id := range ids {
   809  				require.Contains(t, originIDs, id)
   810  			}
   811  		}()
   812  	}
   813  
   814  	unittest.RequireReturnsBefore(t, wg.Wait, 1*time.Second, "timed out waiting for goroutines to finish")
   815  }