github.com/onflow/flow-go@v0.35.7-crescendo-preview.23-atree-inlining/network/alsp/internal/cache_test.go (about) 1 package internal_test 2 3 import ( 4 "errors" 5 "sync" 6 "testing" 7 "time" 8 9 "github.com/rs/zerolog" 10 "github.com/stretchr/testify/require" 11 12 "github.com/onflow/flow-go/model/flow" 13 "github.com/onflow/flow-go/module/metrics" 14 "github.com/onflow/flow-go/network/alsp/internal" 15 "github.com/onflow/flow-go/network/alsp/model" 16 "github.com/onflow/flow-go/utils/unittest" 17 ) 18 19 // TestNewSpamRecordCache tests the creation of a new SpamRecordCache. 20 // It ensures that the returned cache is not nil. It does not test the 21 // functionality of the cache. 22 func TestNewSpamRecordCache(t *testing.T) { 23 sizeLimit := uint32(100) 24 logger := zerolog.Nop() 25 collector := metrics.NewNoopCollector() 26 recordFactory := func(id flow.Identifier) model.ProtocolSpamRecord { 27 return protocolSpamRecordFixture(id) 28 } 29 30 cache := internal.NewSpamRecordCache(sizeLimit, logger, collector, recordFactory) 31 require.NotNil(t, cache) 32 require.Equalf(t, uint(0), cache.Size(), "cache size must be 0") 33 } 34 35 // protocolSpamRecordFixture creates a new protocol spam record with the given origin id. 36 // Args: 37 // - id: the origin id of the spam record. 38 // Returns: 39 // - alsp.ProtocolSpamRecord, the created spam record. 40 // Note that the returned spam record is not a valid spam record. It is used only for testing. 41 func protocolSpamRecordFixture(id flow.Identifier) model.ProtocolSpamRecord { 42 return model.ProtocolSpamRecord{ 43 OriginId: id, 44 Decay: 1000, 45 CutoffCounter: 0, 46 Penalty: 0, 47 } 48 } 49 50 // TestSpamRecordCache_Adjust_Init tests that when the Adjust function is called 51 // on a record that does not exist in the cache, the record is initialized and 52 // the adjust function is applied to the initialized record. 53 func TestSpamRecordCache_Adjust_Init(t *testing.T) { 54 sizeLimit := uint32(100) 55 logger := zerolog.Nop() 56 collector := metrics.NewNoopCollector() 57 58 recordFactoryCalled := 0 59 recordFactory := func(id flow.Identifier) model.ProtocolSpamRecord { 60 require.Less(t, recordFactoryCalled, 2, "record factory must be called only twice") 61 return protocolSpamRecordFixture(id) 62 } 63 adjustFuncIncrement := func(record model.ProtocolSpamRecord) (model.ProtocolSpamRecord, error) { 64 record.Penalty += 1 65 return record, nil 66 } 67 68 cache := internal.NewSpamRecordCache(sizeLimit, logger, collector, recordFactory) 69 require.NotNil(t, cache) 70 require.Zerof(t, cache.Size(), "expected cache to be empty") 71 72 originID1 := unittest.IdentifierFixture() 73 originID2 := unittest.IdentifierFixture() 74 75 // adjusting a spam record for an origin ID that does not exist in the cache should initialize the record. 76 initializedPenalty, err := cache.AdjustWithInit(originID1, adjustFuncIncrement) 77 require.NoError(t, err, "expected no error") 78 require.Equal(t, float64(1), initializedPenalty, "expected initialized penalty to be 1") 79 80 record1, ok := cache.Get(originID1) 81 require.True(t, ok, "expected record to exist") 82 require.NotNil(t, record1, "expected non-nil record") 83 require.Equal(t, originID1, record1.OriginId, "expected record to have correct origin ID") 84 require.False(t, record1.DisallowListed, "expected record to not be disallow listed") 85 require.Equal(t, cache.Size(), uint(1), "expected cache to have one record") 86 87 // adjusting a spam record for an origin ID that already exists in the cache should not initialize the record, 88 // but should apply the adjust function to the existing record. 89 initializedPenalty, err = cache.AdjustWithInit(originID1, adjustFuncIncrement) 90 require.NoError(t, err, "expected no error") 91 require.Equal(t, float64(2), initializedPenalty, "expected initialized penalty to be 2") 92 record1Again, ok := cache.Get(originID1) 93 require.True(t, ok, "expected record to still exist") 94 require.NotNil(t, record1Again, "expected non-nil record") 95 require.Equal(t, originID1, record1Again.OriginId, "expected record to have correct origin ID") 96 require.False(t, record1Again.DisallowListed, "expected record not to be disallow listed") 97 require.Equal(t, cache.Size(), uint(1), "expected cache to still have one record") 98 99 // adjusting a spam record for a different origin ID should initialize the record. 100 // this is to ensure that the record factory is called only once. 101 initializedPenalty, err = cache.AdjustWithInit(originID2, adjustFuncIncrement) 102 require.NoError(t, err, "expected no error") 103 require.Equal(t, float64(1), initializedPenalty, "expected initialized penalty to be 1") 104 record2, ok := cache.Get(originID2) 105 require.True(t, ok, "expected record to exist") 106 require.NotNil(t, record2, "expected non-nil record") 107 require.Equal(t, originID2, record2.OriginId, "expected record to have correct origin ID") 108 require.False(t, record2.DisallowListed, "expected record not to be disallow listed") 109 require.Equal(t, cache.Size(), uint(2), "expected cache to have two records") 110 } 111 112 // TestSpamRecordCache_Adjust tests the Adjust method of the SpamRecordCache. 113 // The test covers the following scenarios: 114 // 1. Adjusting a spam record for an existing origin ID. 115 // 2. Attempting to adjust a spam record with an adjustFunc that returns an error. 116 func TestSpamRecordCache_Adjust_Error(t *testing.T) { 117 sizeLimit := uint32(100) 118 logger := zerolog.Nop() 119 collector := metrics.NewNoopCollector() 120 recordFactory := func(id flow.Identifier) model.ProtocolSpamRecord { 121 return protocolSpamRecordFixture(id) 122 } 123 adjustFnNoOp := func(record model.ProtocolSpamRecord) (model.ProtocolSpamRecord, error) { 124 return record, nil // no-op 125 } 126 127 cache := internal.NewSpamRecordCache(sizeLimit, logger, collector, recordFactory) 128 require.NotNil(t, cache) 129 130 originID1 := unittest.IdentifierFixture() 131 originID2 := unittest.IdentifierFixture() 132 133 // initialize spam records for originID1 and originID2 134 penalty, err := cache.AdjustWithInit(originID1, adjustFnNoOp) 135 require.NoError(t, err, "expected no error") 136 require.Equal(t, 0.0, penalty, "expected penalty to be 0") 137 penalty, err = cache.AdjustWithInit(originID2, adjustFnNoOp) 138 require.NoError(t, err, "expected no error") 139 require.Equal(t, 0.0, penalty, "expected penalty to be 0") 140 141 // test adjusting the spam record for an existing origin ID 142 adjustFunc := func(record model.ProtocolSpamRecord) (model.ProtocolSpamRecord, error) { 143 record.Penalty -= 10 144 return record, nil 145 } 146 penalty, err = cache.AdjustWithInit(originID1, adjustFunc) 147 require.NoError(t, err) 148 require.Equal(t, -10.0, penalty) 149 150 record1, ok := cache.Get(originID1) 151 require.True(t, ok) 152 require.NotNil(t, record1) 153 require.Equal(t, -10.0, record1.Penalty) 154 155 // test adjusting the spam record with an adjustFunc that returns an error 156 adjustFuncError := func(record model.ProtocolSpamRecord) (model.ProtocolSpamRecord, error) { 157 return record, errors.New("adjustment error") 158 } 159 _, err = cache.AdjustWithInit(originID1, adjustFuncError) 160 require.Error(t, err) 161 162 // even though the adjustFunc returned an error, the record should be intact. 163 record1, ok = cache.Get(originID1) 164 require.True(t, ok) 165 require.NotNil(t, record1) 166 require.Equal(t, -10.0, record1.Penalty) 167 } 168 169 // TestSpamRecordCache_Identities tests the Identities method of the SpamRecordCache. 170 // The test covers the following scenarios: 171 // 1. Initializing the cache with multiple spam records. 172 // 2. Checking if the Identities method returns the correct set of origin IDs. 173 func TestSpamRecordCache_Identities(t *testing.T) { 174 sizeLimit := uint32(100) 175 logger := zerolog.Nop() 176 collector := metrics.NewNoopCollector() 177 recordFactory := func(id flow.Identifier) model.ProtocolSpamRecord { 178 return protocolSpamRecordFixture(id) 179 } 180 adjustFnNoOp := func(record model.ProtocolSpamRecord) (model.ProtocolSpamRecord, error) { 181 return record, nil // no-op 182 } 183 184 cache := internal.NewSpamRecordCache(sizeLimit, logger, collector, recordFactory) 185 require.NotNil(t, cache) 186 187 originID1 := unittest.IdentifierFixture() 188 originID2 := unittest.IdentifierFixture() 189 originID3 := unittest.IdentifierFixture() 190 191 // initialize spam records for a few origin IDs 192 _, err := cache.AdjustWithInit(originID1, adjustFnNoOp) 193 require.NoError(t, err) 194 _, err = cache.AdjustWithInit(originID2, adjustFnNoOp) 195 require.NoError(t, err) 196 _, err = cache.AdjustWithInit(originID3, adjustFnNoOp) 197 require.NoError(t, err) 198 199 // check if the Identities method returns the correct set of origin IDs 200 identities := cache.Identities() 201 require.Equal(t, 3, len(identities)) 202 203 identityMap := make(map[flow.Identifier]struct{}) 204 for _, id := range identities { 205 identityMap[id] = struct{}{} 206 } 207 208 require.Contains(t, identityMap, originID1) 209 require.Contains(t, identityMap, originID2) 210 require.Contains(t, identityMap, originID3) 211 } 212 213 // TestSpamRecordCache_Remove tests the Remove method of the SpamRecordCache. 214 // The test covers the following scenarios: 215 // 1. Initializing the cache with multiple spam records. 216 // 2. Removing a spam record and checking if it is removed correctly. 217 // 3. Ensuring the other spam records are still in the cache after removal. 218 // 4. Attempting to remove a non-existent origin ID. 219 func TestSpamRecordCache_Remove(t *testing.T) { 220 sizeLimit := uint32(100) 221 logger := zerolog.Nop() 222 collector := metrics.NewNoopCollector() 223 recordFactory := func(id flow.Identifier) model.ProtocolSpamRecord { 224 return protocolSpamRecordFixture(id) 225 } 226 adjustFnNoOp := func(record model.ProtocolSpamRecord) (model.ProtocolSpamRecord, error) { 227 return record, nil // no-op 228 } 229 230 cache := internal.NewSpamRecordCache(sizeLimit, logger, collector, recordFactory) 231 require.NotNil(t, cache) 232 233 originID1 := unittest.IdentifierFixture() 234 originID2 := unittest.IdentifierFixture() 235 originID3 := unittest.IdentifierFixture() 236 237 // initialize spam records for a few origin IDs 238 _, err := cache.AdjustWithInit(originID1, adjustFnNoOp) 239 require.NoError(t, err) 240 _, err = cache.AdjustWithInit(originID2, adjustFnNoOp) 241 require.NoError(t, err) 242 _, err = cache.AdjustWithInit(originID3, adjustFnNoOp) 243 require.NoError(t, err) 244 245 // remove originID1 and check if the record is removed 246 require.True(t, cache.Remove(originID1)) 247 _, exists := cache.Get(originID1) 248 require.False(t, exists) 249 250 // check if the other origin IDs are still in the cache 251 _, exists = cache.Get(originID2) 252 require.True(t, exists) 253 _, exists = cache.Get(originID3) 254 require.True(t, exists) 255 256 // attempt to remove a non-existent origin ID 257 originID4 := unittest.IdentifierFixture() 258 require.False(t, cache.Remove(originID4)) 259 } 260 261 // TestSpamRecordCache_EdgeCasesAndInvalidInputs tests the edge cases and invalid inputs for SpamRecordCache methods. 262 // The test covers the following scenarios: 263 // 1. Initializing a spam record multiple times. 264 // 2. Adjusting a non-existent spam record. 265 // 3. Removing a spam record multiple times. 266 func TestSpamRecordCache_EdgeCasesAndInvalidInputs(t *testing.T) { 267 sizeLimit := uint32(100) 268 logger := zerolog.Nop() 269 collector := metrics.NewNoopCollector() 270 recordFactory := func(id flow.Identifier) model.ProtocolSpamRecord { 271 return protocolSpamRecordFixture(id) 272 } 273 adjustFnNoOp := func(record model.ProtocolSpamRecord) (model.ProtocolSpamRecord, error) { 274 return record, nil // no-op 275 } 276 277 cache := internal.NewSpamRecordCache(sizeLimit, logger, collector, recordFactory) 278 require.NotNil(t, cache) 279 280 // 1. initializing a spam record multiple times 281 originID1 := unittest.IdentifierFixture() 282 283 _, err := cache.AdjustWithInit(originID1, adjustFnNoOp) 284 require.NoError(t, err) 285 _, err = cache.AdjustWithInit(originID1, adjustFnNoOp) 286 require.NoError(t, err) 287 288 // 2. Test adjusting a non-existent spam record 289 originID2 := unittest.IdentifierFixture() 290 initialPenalty, err := cache.AdjustWithInit(originID2, func(record model.ProtocolSpamRecord) (model.ProtocolSpamRecord, error) { 291 record.Penalty -= 10 292 return record, nil 293 }) 294 require.NoError(t, err) 295 require.Equal(t, float64(-10), initialPenalty) 296 297 // 3. Test removing a spam record multiple times 298 originID3 := unittest.IdentifierFixture() 299 _, err = cache.AdjustWithInit(originID3, adjustFnNoOp) 300 require.NoError(t, err) 301 require.True(t, cache.Remove(originID3)) 302 require.False(t, cache.Remove(originID3)) 303 } 304 305 // TestSpamRecordCache_ConcurrentInitialization tests the concurrent initialization of spam records. 306 // The test covers the following scenarios: 307 // 1. Multiple goroutines initializing spam records for different origin IDs. 308 // 2. Ensuring that all spam records are correctly initialized. 309 func TestSpamRecordCache_ConcurrentInitialization(t *testing.T) { 310 sizeLimit := uint32(100) 311 logger := zerolog.Nop() 312 collector := metrics.NewNoopCollector() 313 recordFactory := func(id flow.Identifier) model.ProtocolSpamRecord { 314 return protocolSpamRecordFixture(id) 315 } 316 adjustFnNoOp := func(record model.ProtocolSpamRecord) (model.ProtocolSpamRecord, error) { 317 return record, nil // no-op 318 } 319 320 cache := internal.NewSpamRecordCache(sizeLimit, logger, collector, recordFactory) 321 require.NotNil(t, cache) 322 323 originIDs := unittest.IdentifierListFixture(10) 324 325 var wg sync.WaitGroup 326 wg.Add(len(originIDs)) 327 328 for _, originID := range originIDs { 329 go func(id flow.Identifier) { 330 defer wg.Done() 331 penalty, err := cache.AdjustWithInit(id, adjustFnNoOp) 332 require.NoError(t, err) 333 require.Equal(t, float64(0), penalty) 334 }(originID) 335 } 336 337 unittest.RequireReturnsBefore(t, wg.Wait, 100*time.Millisecond, "timed out waiting for goroutines to finish") 338 339 // ensure that all spam records are correctly initialized 340 for _, originID := range originIDs { 341 record, found := cache.Get(originID) 342 require.True(t, found) 343 require.NotNil(t, record) 344 require.Equal(t, originID, record.OriginId) 345 } 346 } 347 348 // TestSpamRecordCache_ConcurrentSameRecordAdjust tests the concurrent adjust of the same spam record. 349 // The test covers the following scenarios: 350 // 1. Multiple goroutines attempting to adjust the same spam record concurrently. 351 // 2. Only one of the adjust operations succeeds on initializing the record. 352 // 3. The rest of the adjust operations only update the record (no initialization). 353 func TestSpamRecordCache_ConcurrentSameRecordAdjust(t *testing.T) { 354 sizeLimit := uint32(100) 355 logger := zerolog.Nop() 356 collector := metrics.NewNoopCollector() 357 recordFactory := func(id flow.Identifier) model.ProtocolSpamRecord { 358 return protocolSpamRecordFixture(id) 359 } 360 adjustFn := func(record model.ProtocolSpamRecord) (model.ProtocolSpamRecord, error) { 361 record.Penalty -= 1.0 362 record.DisallowListed = true 363 record.Decay += 1.0 364 return record, nil // no-op 365 } 366 367 cache := internal.NewSpamRecordCache(sizeLimit, logger, collector, recordFactory) 368 require.NotNil(t, cache) 369 370 originID := unittest.IdentifierFixture() 371 const concurrentAttempts = 10 372 373 var wg sync.WaitGroup 374 wg.Add(concurrentAttempts) 375 376 for i := 0; i < concurrentAttempts; i++ { 377 go func() { 378 defer wg.Done() 379 penalty, err := cache.AdjustWithInit(originID, adjustFn) 380 require.NoError(t, err) 381 require.Less(t, penalty, 0.0) // penalty should be negative 382 }() 383 } 384 385 unittest.RequireReturnsBefore(t, wg.Wait, 100*time.Millisecond, "timed out waiting for goroutines to finish") 386 387 // ensure that the record is correctly initialized and adjusted in the cache 388 initDecay := model.SpamRecordFactory()(originID).Decay 389 record, found := cache.Get(originID) 390 require.True(t, found) 391 require.NotNil(t, record) 392 require.Equal(t, concurrentAttempts*-1.0, record.Penalty) 393 require.Equal(t, initDecay+concurrentAttempts*1.0, record.Decay) 394 require.True(t, record.DisallowListed) 395 require.Equal(t, originID, record.OriginId) 396 } 397 398 // TestSpamRecordCache_ConcurrentRemoval tests the concurrent removal of spam records for different origin IDs. 399 // The test covers the following scenarios: 400 // 1. Multiple goroutines removing spam records for different origin IDs concurrently. 401 // 2. The records are correctly removed from the cache. 402 func TestSpamRecordCache_ConcurrentRemoval(t *testing.T) { 403 sizeLimit := uint32(100) 404 logger := zerolog.Nop() 405 collector := metrics.NewNoopCollector() 406 recordFactory := func(id flow.Identifier) model.ProtocolSpamRecord { 407 return protocolSpamRecordFixture(id) 408 } 409 adjustFnNoOp := func(record model.ProtocolSpamRecord) (model.ProtocolSpamRecord, error) { 410 return record, nil // no-op 411 } 412 413 cache := internal.NewSpamRecordCache(sizeLimit, logger, collector, recordFactory) 414 require.NotNil(t, cache) 415 416 originIDs := unittest.IdentifierListFixture(10) 417 for _, originID := range originIDs { 418 penalty, err := cache.AdjustWithInit(originID, adjustFnNoOp) 419 require.NoError(t, err) 420 require.Equal(t, float64(0), penalty) 421 } 422 423 var wg sync.WaitGroup 424 wg.Add(len(originIDs)) 425 426 for _, originID := range originIDs { 427 go func(id flow.Identifier) { 428 defer wg.Done() 429 removed := cache.Remove(id) 430 require.True(t, removed) 431 }(originID) 432 } 433 434 unittest.RequireReturnsBefore(t, wg.Wait, 100*time.Millisecond, "timed out waiting for goroutines to finish") 435 436 // ensure that the records are correctly removed from the cache 437 for _, originID := range originIDs { 438 _, found := cache.Get(originID) 439 require.False(t, found) 440 } 441 442 // ensure that the cache is empty 443 require.Equal(t, uint(0), cache.Size()) 444 } 445 446 // TestSpamRecordCache_ConcurrentUpdatesAndReads tests the concurrent adjustments and reads of spam records for different 447 // origin IDs. The test covers the following scenarios: 448 // 1. Multiple goroutines adjusting spam records for different origin IDs concurrently. 449 // 2. Multiple goroutines getting spam records for different origin IDs concurrently. 450 // 3. The adjusted records are correctly updated in the cache. 451 func TestSpamRecordCache_ConcurrentUpdatesAndReads(t *testing.T) { 452 sizeLimit := uint32(100) 453 logger := zerolog.Nop() 454 collector := metrics.NewNoopCollector() 455 recordFactory := func(id flow.Identifier) model.ProtocolSpamRecord { 456 return protocolSpamRecordFixture(id) 457 } 458 adjustFnNoOp := func(record model.ProtocolSpamRecord) (model.ProtocolSpamRecord, error) { 459 return record, nil // no-op 460 } 461 462 cache := internal.NewSpamRecordCache(sizeLimit, logger, collector, recordFactory) 463 require.NotNil(t, cache) 464 465 originIDs := unittest.IdentifierListFixture(10) 466 for _, originID := range originIDs { 467 penalty, err := cache.AdjustWithInit(originID, adjustFnNoOp) 468 require.NoError(t, err) 469 require.Equal(t, float64(0), penalty) 470 } 471 472 var wg sync.WaitGroup 473 wg.Add(len(originIDs) * 2) 474 475 adjustFunc := func(record model.ProtocolSpamRecord) (model.ProtocolSpamRecord, error) { 476 record.Penalty -= 1 477 return record, nil 478 } 479 480 for _, originID := range originIDs { 481 // adjust spam records concurrently 482 go func(id flow.Identifier) { 483 defer wg.Done() 484 _, err := cache.AdjustWithInit(id, adjustFunc) 485 require.NoError(t, err) 486 }(originID) 487 488 // get spam records concurrently 489 go func(id flow.Identifier) { 490 defer wg.Done() 491 record, found := cache.Get(id) 492 require.True(t, found) 493 require.NotNil(t, record) 494 }(originID) 495 } 496 497 unittest.RequireReturnsBefore(t, wg.Wait, 100*time.Millisecond, "timed out waiting for goroutines to finish") 498 499 // ensure that the records are correctly updated in the cache 500 for _, originID := range originIDs { 501 record, found := cache.Get(originID) 502 require.True(t, found) 503 require.Equal(t, -1.0, record.Penalty) 504 } 505 } 506 507 // TestSpamRecordCache_ConcurrentInitAndRemove tests the concurrent initialization and removal of spam records for different 508 // origin IDs. The test covers the following scenarios: 509 // 1. Multiple goroutines initializing spam records for different origin IDs concurrently. 510 // 2. Multiple goroutines removing spam records for different origin IDs concurrently. 511 // 3. The initialized records are correctly added to the cache. 512 // 4. The removed records are correctly removed from the cache. 513 func TestSpamRecordCache_ConcurrentInitAndRemove(t *testing.T) { 514 sizeLimit := uint32(100) 515 logger := zerolog.Nop() 516 collector := metrics.NewNoopCollector() 517 recordFactory := func(id flow.Identifier) model.ProtocolSpamRecord { 518 return protocolSpamRecordFixture(id) 519 } 520 adjustFnNoOp := func(record model.ProtocolSpamRecord) (model.ProtocolSpamRecord, error) { 521 return record, nil // no-op 522 } 523 524 cache := internal.NewSpamRecordCache(sizeLimit, logger, collector, recordFactory) 525 require.NotNil(t, cache) 526 527 originIDs := unittest.IdentifierListFixture(20) 528 originIDsToAdd := originIDs[:10] 529 originIDsToRemove := originIDs[10:] 530 531 for _, originID := range originIDsToRemove { 532 penalty, err := cache.AdjustWithInit(originID, adjustFnNoOp) 533 require.NoError(t, err) 534 require.Equal(t, float64(0), penalty) 535 } 536 537 var wg sync.WaitGroup 538 wg.Add(len(originIDs)) 539 540 // initialize spam records concurrently 541 for _, originID := range originIDsToAdd { 542 originID := originID // capture range variable 543 go func() { 544 defer wg.Done() 545 penalty, err := cache.AdjustWithInit(originID, adjustFnNoOp) 546 require.NoError(t, err) 547 require.Equal(t, float64(0), penalty) 548 }() 549 } 550 551 // remove spam records concurrently 552 for _, originID := range originIDsToRemove { 553 go func(id flow.Identifier) { 554 defer wg.Done() 555 cache.Remove(id) 556 }(originID) 557 } 558 559 unittest.RequireReturnsBefore(t, wg.Wait, 100*time.Millisecond, "timed out waiting for goroutines to finish") 560 561 // ensure that the initialized records are correctly added to the cache 562 for _, originID := range originIDsToAdd { 563 record, found := cache.Get(originID) 564 require.True(t, found) 565 require.NotNil(t, record) 566 } 567 568 // ensure that the removed records are correctly removed from the cache 569 for _, originID := range originIDsToRemove { 570 _, found := cache.Get(originID) 571 require.False(t, found) 572 } 573 } 574 575 // TestSpamRecordCache_ConcurrentInitRemoveAdjust tests the concurrent initialization, removal, and adjustment of spam 576 // records for different origin IDs. The test covers the following scenarios: 577 // 1. Multiple goroutines initializing spam records for different origin IDs concurrently. 578 // 2. Multiple goroutines removing spam records for different origin IDs concurrently. 579 // 3. Multiple goroutines adjusting spam records for different origin IDs concurrently. 580 func TestSpamRecordCache_ConcurrentInitRemoveAdjust(t *testing.T) { 581 sizeLimit := uint32(100) 582 logger := zerolog.Nop() 583 collector := metrics.NewNoopCollector() 584 recordFactory := func(id flow.Identifier) model.ProtocolSpamRecord { 585 return protocolSpamRecordFixture(id) 586 } 587 adjustFnNoOp := func(record model.ProtocolSpamRecord) (model.ProtocolSpamRecord, error) { 588 return record, nil // no-op 589 } 590 591 cache := internal.NewSpamRecordCache(sizeLimit, logger, collector, recordFactory) 592 require.NotNil(t, cache) 593 594 originIDs := unittest.IdentifierListFixture(30) 595 originIDsToAdd := originIDs[:10] 596 originIDsToRemove := originIDs[10:20] 597 originIDsToAdjust := originIDs[20:] 598 599 for _, originID := range originIDsToRemove { 600 penalty, err := cache.AdjustWithInit(originID, adjustFnNoOp) 601 require.NoError(t, err) 602 require.Equal(t, float64(0), penalty) 603 } 604 605 adjustFunc := func(record model.ProtocolSpamRecord) (model.ProtocolSpamRecord, error) { 606 record.Penalty -= 1 607 return record, nil 608 } 609 610 var wg sync.WaitGroup 611 wg.Add(len(originIDs)) 612 613 // Initialize spam records concurrently 614 for _, originID := range originIDsToAdd { 615 originID := originID // capture range variable 616 go func() { 617 defer wg.Done() 618 penalty, err := cache.AdjustWithInit(originID, adjustFnNoOp) 619 require.NoError(t, err) 620 require.Equal(t, float64(0), penalty) 621 }() 622 } 623 624 // Remove spam records concurrently 625 for _, originID := range originIDsToRemove { 626 go func(id flow.Identifier) { 627 defer wg.Done() 628 cache.Remove(id) 629 }(originID) 630 } 631 632 // Adjust spam records concurrently 633 for _, originID := range originIDsToAdjust { 634 go func(id flow.Identifier) { 635 defer wg.Done() 636 _, _ = cache.AdjustWithInit(id, adjustFunc) 637 }(originID) 638 } 639 640 unittest.RequireReturnsBefore(t, wg.Wait, 100*time.Millisecond, "timed out waiting for goroutines to finish") 641 } 642 643 // TestSpamRecordCache_ConcurrentInitRemoveAndAdjust tests the concurrent initialization, removal, and adjustment of spam 644 // records for different origin IDs. The test covers the following scenarios: 645 // 1. Multiple goroutines initializing spam records for different origin IDs concurrently. 646 // 2. Multiple goroutines removing spam records for different origin IDs concurrently. 647 // 3. Multiple goroutines adjusting spam records for different origin IDs concurrently. 648 // 4. The initialized records are correctly added to the cache. 649 // 5. The removed records are correctly removed from the cache. 650 // 6. The adjusted records are correctly updated in the cache. 651 func TestSpamRecordCache_ConcurrentInitRemoveAndAdjust(t *testing.T) { 652 sizeLimit := uint32(100) 653 logger := zerolog.Nop() 654 collector := metrics.NewNoopCollector() 655 recordFactory := func(id flow.Identifier) model.ProtocolSpamRecord { 656 return protocolSpamRecordFixture(id) 657 } 658 adjustFnNoOp := func(record model.ProtocolSpamRecord) (model.ProtocolSpamRecord, error) { 659 return record, nil // no-op 660 } 661 662 cache := internal.NewSpamRecordCache(sizeLimit, logger, collector, recordFactory) 663 require.NotNil(t, cache) 664 665 originIDs := unittest.IdentifierListFixture(30) 666 originIDsToAdd := originIDs[:10] 667 originIDsToRemove := originIDs[10:20] 668 originIDsToAdjust := originIDs[20:] 669 670 for _, originID := range originIDsToRemove { 671 penalty, err := cache.AdjustWithInit(originID, adjustFnNoOp) 672 require.NoError(t, err) 673 require.Equal(t, float64(0), penalty) 674 } 675 676 for _, originID := range originIDsToAdjust { 677 penalty, err := cache.AdjustWithInit(originID, adjustFnNoOp) 678 require.NoError(t, err) 679 require.Equal(t, float64(0), penalty) 680 } 681 682 var wg sync.WaitGroup 683 wg.Add(len(originIDs)) 684 685 // initialize spam records concurrently 686 for _, originID := range originIDsToAdd { 687 originID := originID 688 go func() { 689 defer wg.Done() 690 penalty, err := cache.AdjustWithInit(originID, adjustFnNoOp) 691 require.NoError(t, err) 692 require.Equal(t, float64(0), penalty) 693 }() 694 } 695 696 // remove spam records concurrently 697 for _, originID := range originIDsToRemove { 698 originID := originID 699 go func() { 700 defer wg.Done() 701 cache.Remove(originID) 702 }() 703 } 704 705 // adjust spam records concurrently 706 for _, originID := range originIDsToAdjust { 707 originID := originID 708 go func() { 709 defer wg.Done() 710 _, err := cache.AdjustWithInit(originID, func(record model.ProtocolSpamRecord) (model.ProtocolSpamRecord, error) { 711 record.Penalty -= 1 712 return record, nil 713 }) 714 require.NoError(t, err) 715 }() 716 } 717 718 unittest.RequireReturnsBefore(t, wg.Wait, 100*time.Millisecond, "timed out waiting for goroutines to finish") 719 720 // ensure that the initialized records are correctly added to the cache 721 for _, originID := range originIDsToAdd { 722 record, found := cache.Get(originID) 723 require.True(t, found) 724 require.NotNil(t, record) 725 } 726 727 // ensure that the removed records are correctly removed from the cache 728 for _, originID := range originIDsToRemove { 729 _, found := cache.Get(originID) 730 require.False(t, found) 731 } 732 733 // ensure that the adjusted records are correctly updated in the cache 734 for _, originID := range originIDsToAdjust { 735 record, found := cache.Get(originID) 736 require.True(t, found) 737 require.NotNil(t, record) 738 require.Equal(t, -1.0, record.Penalty) 739 } 740 } 741 742 // TestSpamRecordCache_ConcurrentIdentitiesAndOperations tests the concurrent calls to Identities method while 743 // other goroutines are initializing or removing spam records. The test covers the following scenarios: 744 // 1. Multiple goroutines initializing spam records for different origin IDs concurrently. 745 // 2. Multiple goroutines removing spam records for different origin IDs concurrently. 746 // 3. Multiple goroutines calling Identities method concurrently. 747 func TestSpamRecordCache_ConcurrentIdentitiesAndOperations(t *testing.T) { 748 sizeLimit := uint32(100) 749 logger := zerolog.Nop() 750 collector := metrics.NewNoopCollector() 751 recordFactory := func(id flow.Identifier) model.ProtocolSpamRecord { 752 return protocolSpamRecordFixture(id) 753 } 754 adjustFnNoOp := func(record model.ProtocolSpamRecord) (model.ProtocolSpamRecord, error) { 755 return record, nil // no-op 756 } 757 758 cache := internal.NewSpamRecordCache(sizeLimit, logger, collector, recordFactory) 759 require.NotNil(t, cache) 760 761 originIDs := unittest.IdentifierListFixture(20) 762 originIDsToAdd := originIDs[:10] 763 originIDsToRemove := originIDs[10:20] 764 765 for _, originID := range originIDsToRemove { 766 penalty, err := cache.AdjustWithInit(originID, adjustFnNoOp) 767 require.NoError(t, err) 768 require.Equal(t, float64(0), penalty) 769 } 770 771 var wg sync.WaitGroup 772 wg.Add(len(originIDs) + 10) 773 774 // initialize spam records concurrently 775 for _, originID := range originIDsToAdd { 776 originID := originID 777 go func() { 778 defer wg.Done() 779 penalty, err := cache.AdjustWithInit(originID, adjustFnNoOp) 780 require.NoError(t, err) 781 require.Equal(t, float64(0), penalty) 782 retrieved, ok := cache.Get(originID) 783 require.True(t, ok) 784 require.NotNil(t, retrieved) 785 }() 786 } 787 788 // remove spam records concurrently 789 for _, originID := range originIDsToRemove { 790 originID := originID 791 go func() { 792 defer wg.Done() 793 require.True(t, cache.Remove(originID)) 794 retrieved, ok := cache.Get(originID) 795 require.False(t, ok) 796 require.Nil(t, retrieved) 797 }() 798 } 799 800 // call Identities method concurrently 801 for i := 0; i < 10; i++ { 802 go func() { 803 defer wg.Done() 804 ids := cache.Identities() 805 // the number of returned IDs should be less than or equal to the number of origin IDs 806 require.True(t, len(ids) <= len(originIDs)) 807 // the returned IDs should be a subset of the origin IDs 808 for _, id := range ids { 809 require.Contains(t, originIDs, id) 810 } 811 }() 812 } 813 814 unittest.RequireReturnsBefore(t, wg.Wait, 1*time.Second, "timed out waiting for goroutines to finish") 815 }