gvisor.dev/gvisor@v0.0.0-20240520182842-f9d4d51c7e0f/pkg/sync/atomicptrmap/atomicptrmap_test.go (about) 1 // Copyright 2020 The gVisor Authors. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package atomicptrmap 16 17 import ( 18 "context" 19 "fmt" 20 "math/rand" 21 "reflect" 22 "runtime" 23 "testing" 24 "time" 25 26 "gvisor.dev/gvisor/pkg/sync" 27 ) 28 29 func TestConsistencyWithGoMap(t *testing.T) { 30 const maxKey = 16 31 var vals [4]*testValue 32 for i := 1; /* leave vals[0] nil */ i < len(vals); i++ { 33 vals[i] = new(testValue) 34 } 35 var ( 36 m = make(map[int64]*testValue) 37 apm testAtomicPtrMap 38 ) 39 for i := 0; i < 100000; i++ { 40 // Apply a random operation to both m and apm and expect them to have 41 // the same result. Bias toward CompareAndSwap, which has the most 42 // cases; bias away from Range and RangeRepeatable, which are 43 // relatively expensive. 44 switch rand.Intn(10) { 45 case 0, 1: // Load 46 key := rand.Int63n(maxKey) 47 want := m[key] 48 got := apm.Load(key) 49 t.Logf("Load(%d) = %p", key, got) 50 if got != want { 51 t.Fatalf("got %p, wanted %p", got, want) 52 } 53 case 2, 3: // Swap 54 key := rand.Int63n(maxKey) 55 val := vals[rand.Intn(len(vals))] 56 want := m[key] 57 if val != nil { 58 m[key] = val 59 } else { 60 delete(m, key) 61 } 62 got := apm.Swap(key, val) 63 t.Logf("Swap(%d, %p) = %p", key, val, got) 64 if got != want { 65 t.Fatalf("got %p, wanted %p", got, want) 66 } 67 case 4, 5, 6, 7: // CompareAndSwap 68 key := rand.Int63n(maxKey) 69 oldVal := vals[rand.Intn(len(vals))] 70 newVal := vals[rand.Intn(len(vals))] 71 want := m[key] 72 if want == oldVal { 73 if newVal != nil { 74 m[key] = newVal 75 } else { 76 delete(m, key) 77 } 78 } 79 got := apm.CompareAndSwap(key, oldVal, newVal) 80 t.Logf("CompareAndSwap(%d, %p, %p) = %p", key, oldVal, newVal, got) 81 if got != want { 82 t.Fatalf("got %p, wanted %p", got, want) 83 } 84 case 8: // Range 85 got := make(map[int64]*testValue) 86 var ( 87 haveDup = false 88 dup int64 89 ) 90 apm.Range(func(key int64, val *testValue) bool { 91 if _, ok := got[key]; ok && !haveDup { 92 haveDup = true 93 dup = key 94 } 95 got[key] = val 96 return true 97 }) 98 t.Logf("Range() = %v", got) 99 if !reflect.DeepEqual(got, m) { 100 t.Fatalf("got %v, wanted %v", got, m) 101 } 102 if haveDup { 103 t.Fatalf("got duplicate key %d", dup) 104 } 105 case 9: // RangeRepeatable 106 got := make(map[int64]*testValue) 107 apm.RangeRepeatable(func(key int64, val *testValue) bool { 108 got[key] = val 109 return true 110 }) 111 t.Logf("RangeRepeatable() = %v", got) 112 if !reflect.DeepEqual(got, m) { 113 t.Fatalf("got %v, wanted %v", got, m) 114 } 115 } 116 } 117 } 118 119 func TestConcurrentHeterogeneous(t *testing.T) { 120 ctx, cancel := context.WithCancel(context.Background()) 121 var ( 122 apm testAtomicPtrMap 123 wg sync.WaitGroup 124 ) 125 defer func() { 126 cancel() 127 wg.Wait() 128 }() 129 130 possibleKeyValuePairs := make(map[int64]map[*testValue]struct{}) 131 addKeyValuePair := func(key int64, val *testValue) { 132 values := possibleKeyValuePairs[key] 133 if values == nil { 134 values = make(map[*testValue]struct{}) 135 possibleKeyValuePairs[key] = values 136 } 137 values[val] = struct{}{} 138 } 139 140 const numValuesPerKey = 4 141 142 // These goroutines use keys not used by any other goroutine. 143 const numPrivateKeys = 3 144 for i := 0; i < numPrivateKeys; i++ { 145 key := int64(i) 146 var vals [numValuesPerKey]*testValue 147 for i := 1; /* leave vals[0] nil */ i < len(vals); i++ { 148 val := new(testValue) 149 vals[i] = val 150 addKeyValuePair(key, val) 151 } 152 wg.Add(1) 153 go func() { 154 defer wg.Done() 155 r := rand.New(rand.NewSource(rand.Int63())) 156 var stored *testValue 157 for ctx.Err() == nil { 158 switch r.Intn(4) { 159 case 0: 160 got := apm.Load(key) 161 if got != stored { 162 t.Errorf("Load(%d): got %p, wanted %p", key, got, stored) 163 return 164 } 165 case 1: 166 val := vals[r.Intn(len(vals))] 167 want := stored 168 stored = val 169 got := apm.Swap(key, val) 170 if got != want { 171 t.Errorf("Swap(%d, %p): got %p, wanted %p", key, val, got, want) 172 return 173 } 174 case 2, 3: 175 oldVal := vals[r.Intn(len(vals))] 176 newVal := vals[r.Intn(len(vals))] 177 want := stored 178 if stored == oldVal { 179 stored = newVal 180 } 181 got := apm.CompareAndSwap(key, oldVal, newVal) 182 if got != want { 183 t.Errorf("CompareAndSwap(%d, %p, %p): got %p, wanted %p", key, oldVal, newVal, got, want) 184 return 185 } 186 } 187 } 188 }() 189 } 190 191 // These goroutines share a small set of keys. 192 const numSharedKeys = 2 193 var ( 194 sharedKeys [numSharedKeys]int64 195 sharedValues = make(map[int64][]*testValue) 196 sharedValuesSet = make(map[int64]map[*testValue]struct{}) 197 ) 198 for i := range sharedKeys { 199 key := int64(numPrivateKeys + i) 200 sharedKeys[i] = key 201 vals := make([]*testValue, numValuesPerKey) 202 valsSet := make(map[*testValue]struct{}) 203 for j := range vals { 204 val := new(testValue) 205 vals[j] = val 206 valsSet[val] = struct{}{} 207 addKeyValuePair(key, val) 208 } 209 sharedValues[key] = vals 210 sharedValuesSet[key] = valsSet 211 } 212 randSharedValue := func(r *rand.Rand, key int64) *testValue { 213 vals := sharedValues[key] 214 return vals[r.Intn(len(vals))] 215 } 216 for i := 0; i < 3; i++ { 217 wg.Add(1) 218 go func() { 219 defer wg.Done() 220 r := rand.New(rand.NewSource(rand.Int63())) 221 for ctx.Err() == nil { 222 keyIndex := r.Intn(len(sharedKeys)) 223 key := sharedKeys[keyIndex] 224 var ( 225 op string 226 got *testValue 227 ) 228 switch r.Intn(4) { 229 case 0: 230 op = "Load" 231 got = apm.Load(key) 232 case 1: 233 op = "Swap" 234 got = apm.Swap(key, randSharedValue(r, key)) 235 case 2, 3: 236 op = "CompareAndSwap" 237 got = apm.CompareAndSwap(key, randSharedValue(r, key), randSharedValue(r, key)) 238 } 239 if got != nil { 240 valsSet := sharedValuesSet[key] 241 if _, ok := valsSet[got]; !ok { 242 t.Errorf("%s: got key %d, value %p; expected value in %v", op, key, got, valsSet) 243 return 244 } 245 } 246 } 247 }() 248 } 249 250 // This goroutine repeatedly searches for unused keys. 251 wg.Add(1) 252 go func() { 253 defer wg.Done() 254 r := rand.New(rand.NewSource(rand.Int63())) 255 for ctx.Err() == nil { 256 key := -1 - r.Int63() 257 if got := apm.Load(key); got != nil { 258 t.Errorf("Load(%d): got %p, wanted nil", key, got) 259 } 260 } 261 }() 262 263 // This goroutine repeatedly calls RangeRepeatable() and checks that each 264 // key corresponds to an expected value. 265 wg.Add(1) 266 go func() { 267 defer wg.Done() 268 abort := false 269 for !abort && ctx.Err() == nil { 270 apm.RangeRepeatable(func(key int64, val *testValue) bool { 271 values, ok := possibleKeyValuePairs[key] 272 if !ok { 273 t.Errorf("RangeRepeatable: got invalid key %d", key) 274 abort = true 275 return false 276 } 277 if _, ok := values[val]; !ok { 278 t.Errorf("RangeRepeatable: got key %d, value %p; expected one of %v", key, val, values) 279 abort = true 280 return false 281 } 282 return true 283 }) 284 } 285 }() 286 287 // Finally, the main goroutine spins for the length of the test calling 288 // Range() and checking that each key that it observes is unique and 289 // corresponds to an expected value. 290 seenKeys := make(map[int64]struct{}) 291 const testDuration = 5 * time.Second 292 end := time.Now().Add(testDuration) 293 abort := false 294 for time.Now().Before(end) { 295 apm.Range(func(key int64, val *testValue) bool { 296 values, ok := possibleKeyValuePairs[key] 297 if !ok { 298 t.Errorf("Range: got invalid key %d", key) 299 abort = true 300 return false 301 } 302 if _, ok := values[val]; !ok { 303 t.Errorf("Range: got key %d, value %p; expected one of %v", key, val, values) 304 abort = true 305 return false 306 } 307 if _, ok := seenKeys[key]; ok { 308 t.Errorf("Range: got duplicate key %d", key) 309 abort = true 310 return false 311 } 312 seenKeys[key] = struct{}{} 313 return true 314 }) 315 if abort { 316 break 317 } 318 for k := range seenKeys { 319 delete(seenKeys, k) 320 } 321 } 322 } 323 324 type benchmarkableMap interface { 325 Load(key int64) *testValue 326 Store(key int64, val *testValue) 327 LoadOrStore(key int64, val *testValue) (*testValue, bool) 328 Delete(key int64) 329 } 330 331 // rwMutexMap implements benchmarkableMap for a RWMutex-protected Go map. 332 type rwMutexMap struct { 333 mu sync.RWMutex 334 m map[int64]*testValue 335 } 336 337 func (m *rwMutexMap) Load(key int64) *testValue { 338 m.mu.RLock() 339 defer m.mu.RUnlock() 340 return m.m[key] 341 } 342 343 func (m *rwMutexMap) Store(key int64, val *testValue) { 344 m.mu.Lock() 345 defer m.mu.Unlock() 346 if m.m == nil { 347 m.m = make(map[int64]*testValue) 348 } 349 m.m[key] = val 350 } 351 352 func (m *rwMutexMap) LoadOrStore(key int64, val *testValue) (*testValue, bool) { 353 m.mu.Lock() 354 defer m.mu.Unlock() 355 if m.m == nil { 356 m.m = make(map[int64]*testValue) 357 } 358 if oldVal, ok := m.m[key]; ok { 359 return oldVal, true 360 } 361 m.m[key] = val 362 return val, false 363 } 364 365 func (m *rwMutexMap) Delete(key int64) { 366 m.mu.Lock() 367 defer m.mu.Unlock() 368 delete(m.m, key) 369 } 370 371 // syncMap implements benchmarkableMap for a sync.Map. 372 type syncMap struct { 373 m sync.Map 374 } 375 376 func (m *syncMap) Load(key int64) *testValue { 377 val, ok := m.m.Load(key) 378 if !ok { 379 return nil 380 } 381 return val.(*testValue) 382 } 383 384 func (m *syncMap) Store(key int64, val *testValue) { 385 m.m.Store(key, val) 386 } 387 388 func (m *syncMap) LoadOrStore(key int64, val *testValue) (*testValue, bool) { 389 actual, loaded := m.m.LoadOrStore(key, val) 390 return actual.(*testValue), loaded 391 } 392 393 func (m *syncMap) Delete(key int64) { 394 m.m.Delete(key) 395 } 396 397 // benchmarkableAtomicPtrMap implements benchmarkableMap for testAtomicPtrMap. 398 type benchmarkableAtomicPtrMap struct { 399 m testAtomicPtrMap 400 } 401 402 func (m *benchmarkableAtomicPtrMap) Load(key int64) *testValue { 403 return m.m.Load(key) 404 } 405 406 func (m *benchmarkableAtomicPtrMap) Store(key int64, val *testValue) { 407 m.m.Store(key, val) 408 } 409 410 func (m *benchmarkableAtomicPtrMap) LoadOrStore(key int64, val *testValue) (*testValue, bool) { 411 if prev := m.m.CompareAndSwap(key, nil, val); prev != nil { 412 return prev, true 413 } 414 return val, false 415 } 416 417 func (m *benchmarkableAtomicPtrMap) Delete(key int64) { 418 m.m.Store(key, nil) 419 } 420 421 // benchmarkableAtomicPtrMapSharded implements benchmarkableMap for testAtomicPtrMapSharded. 422 type benchmarkableAtomicPtrMapSharded struct { 423 m testAtomicPtrMapSharded 424 } 425 426 func (m *benchmarkableAtomicPtrMapSharded) Load(key int64) *testValue { 427 return m.m.Load(key) 428 } 429 430 func (m *benchmarkableAtomicPtrMapSharded) Store(key int64, val *testValue) { 431 m.m.Store(key, val) 432 } 433 434 func (m *benchmarkableAtomicPtrMapSharded) LoadOrStore(key int64, val *testValue) (*testValue, bool) { 435 if prev := m.m.CompareAndSwap(key, nil, val); prev != nil { 436 return prev, true 437 } 438 return val, false 439 } 440 441 func (m *benchmarkableAtomicPtrMapSharded) Delete(key int64) { 442 m.m.Store(key, nil) 443 } 444 445 var mapImpls = [...]struct { 446 name string 447 ctor func() benchmarkableMap 448 }{ 449 { 450 name: "RWMutexMap", 451 ctor: func() benchmarkableMap { 452 return new(rwMutexMap) 453 }, 454 }, 455 { 456 name: "SyncMap", 457 ctor: func() benchmarkableMap { 458 return new(syncMap) 459 }, 460 }, 461 { 462 name: "AtomicPtrMap", 463 ctor: func() benchmarkableMap { 464 return new(benchmarkableAtomicPtrMap) 465 }, 466 }, 467 { 468 name: "AtomicPtrMapSharded", 469 ctor: func() benchmarkableMap { 470 return new(benchmarkableAtomicPtrMapSharded) 471 }, 472 }, 473 } 474 475 func benchmarkStoreDelete(b *testing.B, mapCtor func() benchmarkableMap) { 476 m := mapCtor() 477 val := &testValue{} 478 for i := 0; i < b.N; i++ { 479 m.Store(int64(i), val) 480 } 481 for i := 0; i < b.N; i++ { 482 m.Delete(int64(i)) 483 } 484 } 485 486 func BenchmarkStoreDelete(b *testing.B) { 487 for _, mapImpl := range mapImpls { 488 b.Run(mapImpl.name, func(b *testing.B) { 489 benchmarkStoreDelete(b, mapImpl.ctor) 490 }) 491 } 492 } 493 494 func benchmarkLoadOrStoreDelete(b *testing.B, mapCtor func() benchmarkableMap) { 495 m := mapCtor() 496 val := &testValue{} 497 for i := 0; i < b.N; i++ { 498 m.LoadOrStore(int64(i), val) 499 } 500 for i := 0; i < b.N; i++ { 501 m.Delete(int64(i)) 502 } 503 } 504 505 func BenchmarkLoadOrStoreDelete(b *testing.B) { 506 for _, mapImpl := range mapImpls { 507 b.Run(mapImpl.name, func(b *testing.B) { 508 benchmarkLoadOrStoreDelete(b, mapImpl.ctor) 509 }) 510 } 511 } 512 513 func benchmarkLookupPositive(b *testing.B, mapCtor func() benchmarkableMap) { 514 m := mapCtor() 515 val := &testValue{} 516 for i := 0; i < b.N; i++ { 517 m.Store(int64(i), val) 518 } 519 b.ResetTimer() 520 for i := 0; i < b.N; i++ { 521 m.Load(int64(i)) 522 } 523 } 524 525 func BenchmarkLookupPositive(b *testing.B) { 526 for _, mapImpl := range mapImpls { 527 b.Run(mapImpl.name, func(b *testing.B) { 528 benchmarkLookupPositive(b, mapImpl.ctor) 529 }) 530 } 531 } 532 533 func benchmarkLookupNegative(b *testing.B, mapCtor func() benchmarkableMap) { 534 m := mapCtor() 535 val := &testValue{} 536 for i := 0; i < b.N; i++ { 537 m.Store(int64(i), val) 538 } 539 b.ResetTimer() 540 for i := 0; i < b.N; i++ { 541 m.Load(int64(-1 - i)) 542 } 543 } 544 545 func BenchmarkLookupNegative(b *testing.B) { 546 for _, mapImpl := range mapImpls { 547 b.Run(mapImpl.name, func(b *testing.B) { 548 benchmarkLookupNegative(b, mapImpl.ctor) 549 }) 550 } 551 } 552 553 type benchmarkConcurrentOptions struct { 554 // loadsPerMutationPair is the number of map lookups between each 555 // insertion/deletion pair. 556 loadsPerMutationPair int 557 558 // If changeKeys is true, the keys used by each goroutine change between 559 // iterations of the test. 560 changeKeys bool 561 } 562 563 func benchmarkConcurrent(b *testing.B, mapCtor func() benchmarkableMap, opts benchmarkConcurrentOptions) { 564 var ( 565 started sync.WaitGroup 566 workers sync.WaitGroup 567 ) 568 started.Add(1) 569 570 m := mapCtor() 571 val := &testValue{} 572 // Insert a large number of unused elements into the map so that used 573 // elements are distributed throughout memory. 574 for i := 0; i < 10000; i++ { 575 m.Store(int64(-1-i), val) 576 } 577 // n := ceil(b.N / (opts.loadsPerMutationPair + 2)) 578 n := (b.N + opts.loadsPerMutationPair + 1) / (opts.loadsPerMutationPair + 2) 579 for i, procs := 0, runtime.GOMAXPROCS(0); i < procs; i++ { 580 workerID := i 581 workers.Add(1) 582 go func() { 583 defer workers.Done() 584 started.Wait() 585 for i := 0; i < n; i++ { 586 var key int64 587 if opts.changeKeys { 588 key = int64(workerID*n + i) 589 } else { 590 key = int64(workerID) 591 } 592 m.LoadOrStore(key, val) 593 for j := 0; j < opts.loadsPerMutationPair; j++ { 594 m.Load(key) 595 } 596 m.Delete(key) 597 } 598 }() 599 } 600 601 b.ResetTimer() 602 started.Done() 603 workers.Wait() 604 } 605 606 func BenchmarkConcurrent(b *testing.B) { 607 changeKeysChoices := [...]struct { 608 name string 609 val bool 610 }{ 611 {"FixedKeys", false}, 612 {"ChangingKeys", true}, 613 } 614 writePcts := [...]struct { 615 name string 616 loadsPerMutationPair int 617 }{ 618 {"1PercentWrites", 198}, 619 {"10PercentWrites", 18}, 620 {"50PercentWrites", 2}, 621 } 622 for _, changeKeys := range changeKeysChoices { 623 for _, writePct := range writePcts { 624 for _, mapImpl := range mapImpls { 625 name := fmt.Sprintf("%s_%s_%s", changeKeys.name, writePct.name, mapImpl.name) 626 b.Run(name, func(b *testing.B) { 627 benchmarkConcurrent(b, mapImpl.ctor, benchmarkConcurrentOptions{ 628 loadsPerMutationPair: writePct.loadsPerMutationPair, 629 changeKeys: changeKeys.val, 630 }) 631 }) 632 } 633 } 634 } 635 }