github.com/timandy/routine@v1.1.4-0.20240507073150-e4a3e1fe2ba5/api_thread_local_test.go (about) 1 package routine 2 3 import ( 4 "math/rand" 5 "sync" 6 "testing" 7 8 "github.com/stretchr/testify/assert" 9 ) 10 11 const ( 12 concurrency = 500 13 loopTimes = 200 14 ) 15 16 func TestSupplier(t *testing.T) { 17 var supplier Supplier[string] = func() string { 18 return "Hello" 19 } 20 assert.Equal(t, "Hello", supplier()) 21 // 22 var fun func() string = supplier 23 assert.Equal(t, "Hello", fun()) 24 } 25 26 //=== 27 28 func TestNewThreadLocal_Single(t *testing.T) { 29 tls := NewThreadLocal[string]() 30 tls.Set("Hello") 31 assert.Equal(t, "Hello", tls.Get()) 32 // 33 tls2 := NewThreadLocal[int]() 34 assert.Equal(t, "Hello", tls.Get()) 35 tls2.Set(22) 36 assert.Equal(t, 22, tls2.Get()) 37 // 38 tls2.Set(33) 39 assert.Equal(t, 33, tls2.Get()) 40 // 41 task := GoWait(func(token CancelToken) { 42 assert.Equal(t, "", tls.Get()) 43 assert.Equal(t, 0, tls2.Get()) 44 }) 45 task.Get() 46 } 47 48 func TestNewThreadLocal_Multi(t *testing.T) { 49 tls := NewThreadLocal[string]() 50 tls2 := NewThreadLocal[int]() 51 tls.Set("Hello") 52 tls2.Set(22) 53 assert.Equal(t, 22, tls2.Get()) 54 assert.Equal(t, "Hello", tls.Get()) 55 // 56 tls2.Set(33) 57 assert.Equal(t, 33, tls2.Get()) 58 // 59 task := GoWait(func(token CancelToken) { 60 assert.Equal(t, "", tls.Get()) 61 assert.Equal(t, 0, tls2.Get()) 62 }) 63 task.Get() 64 } 65 66 func TestNewThreadLocal_Concurrency(t *testing.T) { 67 tls := NewThreadLocal[uint64]() 68 tls2 := NewThreadLocal[uint64]() 69 // 70 tls2.Set(33) 71 assert.Equal(t, uint64(33), tls2.Get()) 72 // 73 wg := &sync.WaitGroup{} 74 wg.Add(concurrency) 75 for i := 0; i < concurrency; i++ { 76 assert.Equal(t, uint64(0), tls.Get()) 77 assert.Equal(t, uint64(33), tls2.Get()) 78 Go(func() { 79 assert.Equal(t, uint64(0), tls.Get()) 80 assert.Equal(t, uint64(0), tls2.Get()) 81 v := rand.Uint64() 82 v2 := rand.Uint64() 83 for j := 0; j < loopTimes; j++ { 84 tls.Set(v) 85 tmp := tls.Get() 86 assert.Equal(t, v, tmp) 87 // 88 tls2.Set(v2) 89 tmp2 := tls2.Get() 90 assert.Equal(t, v2, tmp2) 91 } 92 wg.Done() 93 }) 94 } 95 wg.Wait() 96 // 97 task := GoWait(func(token CancelToken) { 98 assert.Equal(t, uint64(0), tls.Get()) 99 assert.Equal(t, uint64(0), tls2.Get()) 100 }) 101 task.Get() 102 } 103 104 func TestNewThreadLocal_Interface(t *testing.T) { 105 tls := NewThreadLocal[Cloneable]() 106 tls2 := NewThreadLocal[Cloneable]() 107 // 108 assert.Nil(t, tls.Get()) 109 assert.Nil(t, tls2.Get()) 110 // 111 tls.Set(nil) 112 tls2.Set(nil) 113 assert.Nil(t, tls.Get()) 114 assert.Nil(t, tls2.Get()) 115 // 116 tls.Set(&personCloneable{Id: 1, Name: "Hello"}) 117 tls2.Set(&personCloneable{Id: 1, Name: "Hello"}) 118 assert.NotNil(t, tls.Get()) 119 assert.NotNil(t, tls2.Get()) 120 // 121 tls.Remove() 122 tls2.Remove() 123 assert.Nil(t, tls.Get()) 124 assert.Nil(t, tls2.Get()) 125 } 126 127 func TestNewThreadLocal_Pointer(t *testing.T) { 128 tls := NewThreadLocal[*personCloneable]() 129 tls2 := NewThreadLocal[*personCloneable]() 130 // 131 assert.Nil(t, tls.Get()) 132 assert.Nil(t, tls2.Get()) 133 // 134 tls.Set(nil) 135 tls2.Set(nil) 136 assert.Nil(t, tls.Get()) 137 assert.Nil(t, tls2.Get()) 138 // 139 tls.Set(&personCloneable{Id: 1, Name: "Hello"}) 140 tls2.Set(&personCloneable{Id: 1, Name: "Hello"}) 141 assert.NotNil(t, tls.Get()) 142 assert.NotNil(t, tls2.Get()) 143 // 144 tls.Remove() 145 tls2.Remove() 146 assert.Nil(t, tls.Get()) 147 assert.Nil(t, tls2.Get()) 148 } 149 150 //=== 151 152 func TestNewThreadLocalWithInitial_Single(t *testing.T) { 153 tls := NewThreadLocalWithInitial[string](func() string { 154 return "Hello" 155 }) 156 assert.Equal(t, "Hello", tls.Get()) 157 // 158 tls2 := NewThreadLocalWithInitial[int](func() int { 159 return 22 160 }) 161 assert.Equal(t, "Hello", tls.Get()) 162 assert.Equal(t, 22, tls2.Get()) 163 // 164 tls2.Set(33) 165 assert.Equal(t, 33, tls2.Get()) 166 // 167 task := GoWait(func(token CancelToken) { 168 assert.Equal(t, "Hello", tls.Get()) 169 assert.Equal(t, 22, tls2.Get()) 170 }) 171 task.Get() 172 } 173 174 func TestNewThreadLocalWithInitial_Multi(t *testing.T) { 175 tls := NewThreadLocalWithInitial[string](func() string { 176 return "Hello" 177 }) 178 tls2 := NewThreadLocalWithInitial[int](func() int { 179 return 22 180 }) 181 tls.Set("Hello") 182 tls2.Set(22) 183 assert.Equal(t, 22, tls2.Get()) 184 assert.Equal(t, "Hello", tls.Get()) 185 // 186 tls2.Set(33) 187 assert.Equal(t, 33, tls2.Get()) 188 // 189 task := GoWait(func(token CancelToken) { 190 assert.Equal(t, "Hello", tls.Get()) 191 assert.Equal(t, 22, tls2.Get()) 192 }) 193 task.Get() 194 } 195 196 func TestNewThreadLocalWithInitial_Concurrency(t *testing.T) { 197 tls := NewThreadLocalWithInitial[any](func() any { 198 return "Hello" 199 }) 200 tls2 := NewThreadLocalWithInitial[uint64](func() uint64 { 201 return uint64(22) 202 }) 203 // 204 tls2.Set(33) 205 assert.Equal(t, uint64(33), tls2.Get()) 206 // 207 wg := &sync.WaitGroup{} 208 wg.Add(concurrency) 209 for i := 0; i < concurrency; i++ { 210 assert.Equal(t, "Hello", tls.Get()) 211 assert.Equal(t, uint64(33), tls2.Get()) 212 Go(func() { 213 assert.Equal(t, "Hello", tls.Get()) 214 assert.Equal(t, uint64(22), tls2.Get()) 215 v := rand.Uint64() 216 v2 := rand.Uint64() 217 for j := 0; j < loopTimes; j++ { 218 tls.Set(v) 219 tmp := tls.Get() 220 assert.Equal(t, v, tmp.(uint64)) 221 // 222 tls2.Set(v2) 223 tmp2 := tls2.Get() 224 assert.Equal(t, v2, tmp2) 225 } 226 wg.Done() 227 }) 228 } 229 wg.Wait() 230 // 231 task := GoWait(func(token CancelToken) { 232 assert.Equal(t, "Hello", tls.Get()) 233 assert.Equal(t, uint64(22), tls2.Get()) 234 }) 235 task.Get() 236 } 237 238 func TestNewThreadLocalWithInitial_Interface(t *testing.T) { 239 tls := NewThreadLocalWithInitial[Cloneable](func() Cloneable { 240 return nil 241 }) 242 tls2 := NewThreadLocalWithInitial[Cloneable](func() Cloneable { 243 return nil 244 }) 245 // 246 assert.Nil(t, tls.Get()) 247 assert.Nil(t, tls2.Get()) 248 // 249 tls.Set(nil) 250 tls2.Set(nil) 251 assert.Nil(t, tls.Get()) 252 assert.Nil(t, tls2.Get()) 253 // 254 tls.Set(&personCloneable{Id: 1, Name: "Hello"}) 255 tls2.Set(&personCloneable{Id: 1, Name: "Hello"}) 256 assert.NotNil(t, tls.Get()) 257 assert.NotNil(t, tls2.Get()) 258 // 259 tls.Remove() 260 tls2.Remove() 261 assert.Nil(t, tls.Get()) 262 assert.Nil(t, tls2.Get()) 263 } 264 265 func TestNewThreadLocalWithInitial_Pointer(t *testing.T) { 266 tls := NewThreadLocalWithInitial[*personCloneable](func() *personCloneable { 267 return nil 268 }) 269 tls2 := NewThreadLocalWithInitial[*personCloneable](func() *personCloneable { 270 return nil 271 }) 272 // 273 assert.Nil(t, tls.Get()) 274 assert.Nil(t, tls2.Get()) 275 // 276 tls.Set(nil) 277 tls2.Set(nil) 278 assert.Nil(t, tls.Get()) 279 assert.Nil(t, tls2.Get()) 280 // 281 tls.Set(&personCloneable{Id: 1, Name: "Hello"}) 282 tls2.Set(&personCloneable{Id: 1, Name: "Hello"}) 283 assert.NotNil(t, tls.Get()) 284 assert.NotNil(t, tls2.Get()) 285 // 286 tls.Remove() 287 tls2.Remove() 288 assert.Nil(t, tls.Get()) 289 assert.Nil(t, tls2.Get()) 290 } 291 292 //=== 293 294 func TestNewInheritableThreadLocal_Single(t *testing.T) { 295 tls := NewInheritableThreadLocal[string]() 296 tls.Set("Hello") 297 assert.Equal(t, "Hello", tls.Get()) 298 // 299 tls2 := NewInheritableThreadLocal[int]() 300 assert.Equal(t, "Hello", tls.Get()) 301 tls2.Set(22) 302 assert.Equal(t, 22, tls2.Get()) 303 // 304 tls2.Set(33) 305 assert.Equal(t, 33, tls2.Get()) 306 // 307 task := GoWait(func(token CancelToken) { 308 assert.Equal(t, "Hello", tls.Get()) 309 assert.Equal(t, 33, tls2.Get()) 310 }) 311 task.Get() 312 } 313 314 func TestNewInheritableThreadLocal_Multi(t *testing.T) { 315 tls := NewInheritableThreadLocal[string]() 316 tls2 := NewInheritableThreadLocal[int]() 317 tls.Set("Hello") 318 tls2.Set(22) 319 assert.Equal(t, 22, tls2.Get()) 320 assert.Equal(t, "Hello", tls.Get()) 321 // 322 tls2.Set(33) 323 assert.Equal(t, 33, tls2.Get()) 324 // 325 task := GoWait(func(token CancelToken) { 326 assert.Equal(t, "Hello", tls.Get()) 327 assert.Equal(t, 33, tls2.Get()) 328 }) 329 task.Get() 330 } 331 332 func TestNewInheritableThreadLocal_Concurrency(t *testing.T) { 333 tls := NewInheritableThreadLocal[uint64]() 334 tls2 := NewInheritableThreadLocal[uint64]() 335 // 336 tls2.Set(33) 337 assert.Equal(t, uint64(33), tls2.Get()) 338 // 339 wg := &sync.WaitGroup{} 340 wg.Add(concurrency) 341 for i := 0; i < concurrency; i++ { 342 assert.Equal(t, uint64(0), tls.Get()) 343 assert.Equal(t, uint64(33), tls2.Get()) 344 Go(func() { 345 assert.Equal(t, uint64(0), tls.Get()) 346 assert.Equal(t, uint64(33), tls2.Get()) 347 v := rand.Uint64() 348 v2 := rand.Uint64() 349 for j := 0; j < loopTimes; j++ { 350 tls.Set(v) 351 tmp := tls.Get() 352 assert.Equal(t, v, tmp) 353 // 354 tls2.Set(v2) 355 tmp2 := tls2.Get() 356 assert.Equal(t, v2, tmp2) 357 } 358 wg.Done() 359 }) 360 } 361 wg.Wait() 362 // 363 task := GoWait(func(token CancelToken) { 364 assert.Equal(t, uint64(0), tls.Get()) 365 assert.Equal(t, uint64(33), tls2.Get()) 366 }) 367 task.Get() 368 } 369 370 func TestNewInheritableThreadLocal_Interface(t *testing.T) { 371 tls := NewInheritableThreadLocal[Cloneable]() 372 tls2 := NewInheritableThreadLocal[Cloneable]() 373 // 374 assert.Nil(t, tls.Get()) 375 assert.Nil(t, tls2.Get()) 376 // 377 tls.Set(nil) 378 tls2.Set(nil) 379 assert.Nil(t, tls.Get()) 380 assert.Nil(t, tls2.Get()) 381 // 382 tls.Set(&personCloneable{Id: 1, Name: "Hello"}) 383 tls2.Set(&personCloneable{Id: 1, Name: "Hello"}) 384 assert.NotNil(t, tls.Get()) 385 assert.NotNil(t, tls2.Get()) 386 // 387 tls.Remove() 388 tls2.Remove() 389 assert.Nil(t, tls.Get()) 390 assert.Nil(t, tls2.Get()) 391 } 392 393 func TestNewInheritableThreadLocal_Pointer(t *testing.T) { 394 tls := NewInheritableThreadLocal[*personCloneable]() 395 tls2 := NewInheritableThreadLocal[*personCloneable]() 396 // 397 assert.Nil(t, tls.Get()) 398 assert.Nil(t, tls2.Get()) 399 // 400 tls.Set(nil) 401 tls2.Set(nil) 402 assert.Nil(t, tls.Get()) 403 assert.Nil(t, tls2.Get()) 404 // 405 tls.Set(&personCloneable{Id: 1, Name: "Hello"}) 406 tls2.Set(&personCloneable{Id: 1, Name: "Hello"}) 407 assert.NotNil(t, tls.Get()) 408 assert.NotNil(t, tls2.Get()) 409 // 410 tls.Remove() 411 tls2.Remove() 412 assert.Nil(t, tls.Get()) 413 assert.Nil(t, tls2.Get()) 414 } 415 416 //=== 417 418 func TestNewInheritableThreadLocalWithInitial_Single(t *testing.T) { 419 tls := NewInheritableThreadLocalWithInitial[string](func() string { 420 return "Hello" 421 }) 422 assert.Equal(t, "Hello", tls.Get()) 423 // 424 tls2 := NewInheritableThreadLocalWithInitial[int](func() int { 425 return 22 426 }) 427 assert.Equal(t, "Hello", tls.Get()) 428 assert.Equal(t, 22, tls2.Get()) 429 // 430 tls2.Set(33) 431 assert.Equal(t, 33, tls2.Get()) 432 // 433 task := GoWait(func(token CancelToken) { 434 assert.Equal(t, "Hello", tls.Get()) 435 assert.Equal(t, 33, tls2.Get()) 436 }) 437 task.Get() 438 } 439 440 func TestNewInheritableThreadLocalWithInitial_Multi(t *testing.T) { 441 tls := NewInheritableThreadLocalWithInitial[string](func() string { 442 return "Hello" 443 }) 444 tls2 := NewInheritableThreadLocalWithInitial[int](func() int { 445 return 22 446 }) 447 tls.Set("Hello") 448 tls2.Set(22) 449 assert.Equal(t, 22, tls2.Get()) 450 assert.Equal(t, "Hello", tls.Get()) 451 // 452 tls2.Set(33) 453 assert.Equal(t, 33, tls2.Get()) 454 // 455 task := GoWait(func(token CancelToken) { 456 assert.Equal(t, "Hello", tls.Get()) 457 assert.Equal(t, 33, tls2.Get()) 458 }) 459 task.Get() 460 } 461 462 func TestNewInheritableThreadLocalWithInitial_Concurrency(t *testing.T) { 463 tls := NewInheritableThreadLocalWithInitial[any](func() any { 464 return "Hello" 465 }) 466 tls2 := NewInheritableThreadLocalWithInitial[uint64](func() uint64 { 467 return uint64(22) 468 }) 469 // 470 tls2.Set(33) 471 assert.Equal(t, uint64(33), tls2.Get()) 472 // 473 wg := &sync.WaitGroup{} 474 wg.Add(concurrency) 475 for i := 0; i < concurrency; i++ { 476 assert.Equal(t, "Hello", tls.Get()) 477 assert.Equal(t, uint64(33), tls2.Get()) 478 Go(func() { 479 assert.Equal(t, "Hello", tls.Get()) 480 assert.Equal(t, uint64(33), tls2.Get()) 481 v := rand.Uint64() 482 v2 := rand.Uint64() 483 for j := 0; j < loopTimes; j++ { 484 tls.Set(v) 485 tmp := tls.Get() 486 assert.Equal(t, v, tmp.(uint64)) 487 // 488 tls2.Set(v2) 489 tmp2 := tls2.Get() 490 assert.Equal(t, v2, tmp2) 491 } 492 wg.Done() 493 }) 494 } 495 wg.Wait() 496 // 497 task := GoWait(func(token CancelToken) { 498 assert.Equal(t, "Hello", tls.Get()) 499 assert.Equal(t, uint64(33), tls2.Get()) 500 }) 501 task.Get() 502 } 503 504 func TestNewInheritableThreadLocalWithInitial_Interface(t *testing.T) { 505 tls := NewInheritableThreadLocalWithInitial[Cloneable](func() Cloneable { 506 return nil 507 }) 508 tls2 := NewInheritableThreadLocalWithInitial[Cloneable](func() Cloneable { 509 return nil 510 }) 511 // 512 assert.Nil(t, tls.Get()) 513 assert.Nil(t, tls2.Get()) 514 // 515 tls.Set(nil) 516 tls2.Set(nil) 517 assert.Nil(t, tls.Get()) 518 assert.Nil(t, tls2.Get()) 519 // 520 tls.Set(&personCloneable{Id: 1, Name: "Hello"}) 521 tls2.Set(&personCloneable{Id: 1, Name: "Hello"}) 522 assert.NotNil(t, tls.Get()) 523 assert.NotNil(t, tls2.Get()) 524 // 525 tls.Remove() 526 tls2.Remove() 527 assert.Nil(t, tls.Get()) 528 assert.Nil(t, tls2.Get()) 529 } 530 531 func TestNewInheritableThreadLocalWithInitial_Pointer(t *testing.T) { 532 tls := NewInheritableThreadLocalWithInitial[*personCloneable](func() *personCloneable { 533 return nil 534 }) 535 tls2 := NewInheritableThreadLocalWithInitial[*personCloneable](func() *personCloneable { 536 return nil 537 }) 538 // 539 assert.Nil(t, tls.Get()) 540 assert.Nil(t, tls2.Get()) 541 // 542 tls.Set(nil) 543 tls2.Set(nil) 544 assert.Nil(t, tls.Get()) 545 assert.Nil(t, tls2.Get()) 546 // 547 tls.Set(&personCloneable{Id: 1, Name: "Hello"}) 548 tls2.Set(&personCloneable{Id: 1, Name: "Hello"}) 549 assert.NotNil(t, tls.Get()) 550 assert.NotNil(t, tls2.Get()) 551 // 552 tls.Remove() 553 tls2.Remove() 554 assert.Nil(t, tls.Get()) 555 assert.Nil(t, tls2.Get()) 556 } 557 558 //=== 559 560 // BenchmarkThreadLocal-8 13636471 94.17 ns/op 7 B/op 0 allocs/op 561 func BenchmarkThreadLocal(b *testing.B) { 562 tlsCount := 100 563 tlsSlice := make([]ThreadLocal[int], tlsCount) 564 for i := 0; i < tlsCount; i++ { 565 tlsSlice[i] = NewThreadLocal[int]() 566 } 567 b.ReportAllocs() 568 b.ResetTimer() 569 for i := 0; i < b.N; i++ { 570 index := i % tlsCount 571 tls := tlsSlice[index] 572 initValue := tls.Get() 573 if initValue != 0 { 574 b.Fail() 575 } 576 tls.Set(i) 577 if tls.Get() != i { 578 b.Fail() 579 } 580 tls.Remove() 581 } 582 } 583 584 // BenchmarkThreadLocalWithInitial-8 13674153 86.76 ns/op 7 B/op 0 allocs/op 585 func BenchmarkThreadLocalWithInitial(b *testing.B) { 586 tlsCount := 100 587 tlsSlice := make([]ThreadLocal[int], tlsCount) 588 for i := 0; i < tlsCount; i++ { 589 index := i 590 tlsSlice[i] = NewThreadLocalWithInitial[int](func() int { 591 return index 592 }) 593 } 594 b.ReportAllocs() 595 b.ResetTimer() 596 for i := 0; i < b.N; i++ { 597 index := i % tlsCount 598 tls := tlsSlice[index] 599 initValue := tls.Get() 600 if initValue != index { 601 b.Fail() 602 } 603 tls.Set(i) 604 if tls.Get() != i { 605 b.Fail() 606 } 607 tls.Remove() 608 } 609 } 610 611 // BenchmarkInheritableThreadLocal-8 13917819 84.27 ns/op 7 B/op 0 allocs/op 612 func BenchmarkInheritableThreadLocal(b *testing.B) { 613 tlsCount := 100 614 tlsSlice := make([]ThreadLocal[int], tlsCount) 615 for i := 0; i < tlsCount; i++ { 616 tlsSlice[i] = NewInheritableThreadLocal[int]() 617 } 618 b.ReportAllocs() 619 b.ResetTimer() 620 for i := 0; i < b.N; i++ { 621 index := i % tlsCount 622 tls := tlsSlice[index] 623 initValue := tls.Get() 624 if initValue != 0 { 625 b.Fail() 626 } 627 tls.Set(i) 628 if tls.Get() != i { 629 b.Fail() 630 } 631 tls.Remove() 632 } 633 } 634 635 // BenchmarkInheritableThreadLocalWithInitial-8 13483130 90.03 ns/op 7 B/op 0 allocs/op 636 func BenchmarkInheritableThreadLocalWithInitial(b *testing.B) { 637 tlsCount := 100 638 tlsSlice := make([]ThreadLocal[int], tlsCount) 639 for i := 0; i < tlsCount; i++ { 640 index := i 641 tlsSlice[i] = NewInheritableThreadLocalWithInitial[int](func() int { 642 return index 643 }) 644 } 645 b.ReportAllocs() 646 b.ResetTimer() 647 for i := 0; i < b.N; i++ { 648 index := i % tlsCount 649 tls := tlsSlice[index] 650 initValue := tls.Get() 651 if initValue != index { 652 b.Fail() 653 } 654 tls.Set(i) 655 if tls.Get() != i { 656 b.Fail() 657 } 658 tls.Remove() 659 } 660 }