github.com/weaviate/weaviate@v1.24.6/adapters/repos/db/vector/common/sharded_locks_test.go (about) 1 // _ _ 2 // __ _____ __ ___ ___ __ _| |_ ___ 3 // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \ 4 // \ V V / __/ (_| |\ V /| | (_| | || __/ 5 // \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___| 6 // 7 // Copyright © 2016 - 2024 Weaviate B.V. All rights reserved. 8 // 9 // CONTACT: hello@weaviate.io 10 // 11 12 package common 13 14 import ( 15 "sync" 16 "testing" 17 "time" 18 19 "github.com/stretchr/testify/require" 20 ) 21 22 func TestShardedLocks_ParallelLocksAll(t *testing.T) { 23 // no asserts 24 // ensures parallel LockAll does not fall into deadlock 25 count := 10 26 sl := NewDefaultShardedLocks() 27 28 wg := new(sync.WaitGroup) 29 wg.Add(count) 30 for i := 0; i < count; i++ { 31 go func() { 32 defer wg.Done() 33 sl.LockAll() 34 sl.UnlockAll() 35 }() 36 } 37 wg.Wait() 38 } 39 40 func TestShardedLocks_MixedLocks(t *testing.T) { 41 // no asserts 42 // ensures parallel LockAll + RLockAll + Lock + RLock does not fall into deadlock 43 count := 1000 44 sl := NewShardedLocks(10) 45 46 wg := new(sync.WaitGroup) 47 wg.Add(count) 48 for i := 0; i < count; i++ { 49 go func(i int) { 50 defer wg.Done() 51 id := uint64(i) 52 if i%5 == 0 { 53 sl.LockAll() 54 sl.UnlockAll() 55 } else { 56 sl.Lock(id) 57 sl.Unlock(id) 58 } 59 }(i) 60 } 61 wg.Wait() 62 } 63 64 func TestShardedLocks(t *testing.T) { 65 t.Run("Lock", func(t *testing.T) { 66 t.Parallel() 67 m := NewShardedLocks(5) 68 69 m.Lock(1) 70 71 ch := make(chan struct{}) 72 go func() { 73 time.Sleep(50 * time.Millisecond) 74 m.Unlock(1) 75 76 close(ch) 77 }() 78 79 m.Lock(1) 80 81 select { 82 case <-ch: 83 case <-time.After(1 * time.Second): 84 require.Fail(t, "should be unlocked") 85 } 86 87 m.Unlock(1) 88 }) 89 90 t.Run("Lock blocks LockAll", func(t *testing.T) { 91 t.Parallel() 92 m := NewShardedLocks(5) 93 94 m.Lock(1) 95 96 ch := make(chan struct{}) 97 go func() { 98 time.Sleep(50 * time.Millisecond) 99 m.Unlock(1) 100 101 close(ch) 102 }() 103 104 m.LockAll() 105 106 select { 107 case <-ch: 108 case <-time.After(1 * time.Second): 109 require.Fail(t, "should be unlocked") 110 } 111 112 m.UnlockAll() 113 }) 114 115 t.Run("LockAll blocks Lock", func(t *testing.T) { 116 t.Parallel() 117 m := NewShardedLocks(5) 118 119 m.LockAll() 120 121 ch := make(chan struct{}) 122 go func() { 123 time.Sleep(50 * time.Millisecond) 124 m.UnlockAll() 125 126 close(ch) 127 }() 128 129 m.Lock(1) 130 131 select { 132 case <-ch: 133 case <-time.After(1 * time.Second): 134 require.Fail(t, "should be unlocked") 135 } 136 137 m.Unlock(1) 138 }) 139 140 t.Run("LockAll blocks LockAll", func(t *testing.T) { 141 t.Parallel() 142 m := NewShardedLocks(5) 143 144 m.LockAll() 145 146 ch := make(chan struct{}) 147 go func() { 148 time.Sleep(50 * time.Millisecond) 149 m.UnlockAll() 150 151 close(ch) 152 }() 153 154 m.LockAll() 155 156 select { 157 case <-ch: 158 case <-time.After(1 * time.Second): 159 require.Fail(t, "should be unlocked") 160 } 161 162 m.UnlockAll() 163 }) 164 165 t.Run("UnlockAll releases all locks", func(t *testing.T) { 166 t.Parallel() 167 m := NewShardedLocks(5) 168 169 m.LockAll() 170 m.UnlockAll() 171 172 m.Lock(1) 173 m.Unlock(1) 174 }) 175 176 t.Run("unlock should wake up next waiting lock", func(t *testing.T) { 177 t.Parallel() 178 m := NewShardedLocks(2) 179 180 m.Lock(1) 181 182 ch1 := make(chan struct{}) 183 ch2 := make(chan struct{}) 184 185 go func() { 186 defer close(ch1) 187 188 m.Lock(1) 189 }() 190 191 go func() { 192 defer close(ch2) 193 194 time.Sleep(100 * time.Millisecond) 195 m.Lock(1) 196 }() 197 198 time.Sleep(10 * time.Millisecond) 199 m.Unlock(1) 200 201 <-ch1 202 203 m.Unlock(1) 204 205 <-ch2 206 207 m.Unlock(1) 208 }) 209 } 210 211 func TestShardedRWLocks_ParallelLocksAll(t *testing.T) { 212 // no asserts 213 // ensures parallel LockAll does not fall into deadlock 214 count := 10 215 sl := NewDefaultShardedRWLocks() 216 217 wg := new(sync.WaitGroup) 218 wg.Add(count) 219 for i := 0; i < count; i++ { 220 go func() { 221 defer wg.Done() 222 sl.LockAll() 223 sl.UnlockAll() 224 }() 225 } 226 wg.Wait() 227 } 228 229 func TestShardedRWLocks_ParallelRLocksAll(t *testing.T) { 230 // no asserts 231 // ensures parallel RLockAll does not fall into deadlock 232 count := 10 233 sl := NewDefaultShardedRWLocks() 234 235 wg := new(sync.WaitGroup) 236 wg.Add(count) 237 for i := 0; i < count; i++ { 238 go func() { 239 defer wg.Done() 240 sl.RLockAll() 241 sl.RUnlockAll() 242 }() 243 } 244 wg.Wait() 245 } 246 247 func TestShardedRWLocks_ParallelLocksAllAndRLocksAll(t *testing.T) { 248 // no asserts 249 // ensures parallel LockAll + RLockAll does not fall into deadlock 250 count := 50 251 sl := NewDefaultShardedRWLocks() 252 253 wg := new(sync.WaitGroup) 254 wg.Add(count) 255 for i := 0; i < count; i++ { 256 go func(i int) { 257 defer wg.Done() 258 if i%2 == 0 { 259 sl.LockAll() 260 sl.UnlockAll() 261 } else { 262 sl.RLockAll() 263 sl.RUnlockAll() 264 } 265 }(i) 266 } 267 wg.Wait() 268 } 269 270 func TestShardedRWLocks_MixedLocks(t *testing.T) { 271 // no asserts 272 // ensures parallel LockAll + RLockAll + Lock + RLock does not fall into deadlock 273 count := 1000 274 sl := NewShardedRWLocks(10) 275 276 wg := new(sync.WaitGroup) 277 wg.Add(count) 278 for i := 0; i < count; i++ { 279 go func(i int) { 280 defer wg.Done() 281 id := uint64(i) 282 if i%5 == 0 { 283 if i%2 == 0 { 284 sl.LockAll() 285 sl.UnlockAll() 286 } else { 287 sl.RLockAll() 288 sl.RUnlockAll() 289 } 290 } else { 291 if i%2 == 0 { 292 sl.Lock(id) 293 sl.Unlock(id) 294 } else { 295 sl.RLock(id) 296 sl.RUnlock(id) 297 } 298 } 299 }(i) 300 } 301 wg.Wait() 302 } 303 304 func TestShardedRWLocks(t *testing.T) { 305 t.Run("RLock", func(t *testing.T) { 306 t.Parallel() 307 m := NewShardedRWLocks(5) 308 309 m.RLock(1) 310 m.RLock(1) 311 312 m.RUnlock(1) 313 m.RUnlock(1) 314 }) 315 316 t.Run("Lock", func(t *testing.T) { 317 t.Parallel() 318 m := NewShardedRWLocks(5) 319 320 m.Lock(1) 321 322 ch := make(chan struct{}) 323 go func() { 324 time.Sleep(50 * time.Millisecond) 325 m.Unlock(1) 326 327 close(ch) 328 }() 329 330 m.Lock(1) 331 332 select { 333 case <-ch: 334 case <-time.After(1 * time.Second): 335 require.Fail(t, "should be unlocked") 336 } 337 338 m.Unlock(1) 339 }) 340 341 t.Run("RLock blocks Lock", func(t *testing.T) { 342 t.Parallel() 343 m := NewShardedRWLocks(5) 344 345 m.RLock(1) 346 347 ch := make(chan struct{}) 348 go func() { 349 time.Sleep(50 * time.Millisecond) 350 m.RUnlock(1) 351 352 close(ch) 353 }() 354 355 m.Lock(1) 356 357 select { 358 case <-ch: 359 case <-time.After(1 * time.Second): 360 require.Fail(t, "should be unlocked") 361 } 362 363 m.Unlock(1) 364 }) 365 366 t.Run("Lock blocks RLock", func(t *testing.T) { 367 t.Parallel() 368 m := NewShardedRWLocks(5) 369 370 m.Lock(1) 371 372 ch := make(chan struct{}) 373 go func() { 374 time.Sleep(50 * time.Millisecond) 375 m.Unlock(1) 376 377 close(ch) 378 }() 379 380 m.RLock(1) 381 382 select { 383 case <-ch: 384 default: 385 require.Fail(t, "should be unlocked") 386 } 387 388 m.RUnlock(1) 389 }) 390 391 t.Run("Lock blocks LockAll", func(t *testing.T) { 392 t.Parallel() 393 m := NewShardedRWLocks(5) 394 395 m.Lock(1) 396 397 ch := make(chan struct{}) 398 go func() { 399 time.Sleep(50 * time.Millisecond) 400 m.Unlock(1) 401 402 close(ch) 403 }() 404 405 m.LockAll() 406 407 select { 408 case <-ch: 409 case <-time.After(1 * time.Second): 410 require.Fail(t, "should be unlocked") 411 } 412 413 m.UnlockAll() 414 }) 415 416 t.Run("LockAll blocks Lock", func(t *testing.T) { 417 t.Parallel() 418 m := NewShardedRWLocks(5) 419 420 m.LockAll() 421 422 ch := make(chan struct{}) 423 go func() { 424 time.Sleep(50 * time.Millisecond) 425 m.UnlockAll() 426 427 close(ch) 428 }() 429 430 m.Lock(1) 431 432 select { 433 case <-ch: 434 case <-time.After(1 * time.Second): 435 require.Fail(t, "should be unlocked") 436 } 437 438 m.Unlock(1) 439 }) 440 441 t.Run("LockAll blocks RLock", func(t *testing.T) { 442 t.Parallel() 443 m := NewShardedRWLocks(5) 444 445 m.LockAll() 446 447 ch := make(chan struct{}) 448 go func() { 449 time.Sleep(50 * time.Millisecond) 450 m.UnlockAll() 451 452 close(ch) 453 }() 454 455 m.RLock(1) 456 457 select { 458 case <-ch: 459 case <-time.After(1 * time.Second): 460 require.Fail(t, "should be unlocked") 461 } 462 463 m.RUnlock(1) 464 }) 465 466 t.Run("LockAll blocks LockAll", func(t *testing.T) { 467 t.Parallel() 468 m := NewShardedRWLocks(5) 469 470 m.LockAll() 471 472 ch := make(chan struct{}) 473 go func() { 474 time.Sleep(50 * time.Millisecond) 475 m.UnlockAll() 476 477 close(ch) 478 }() 479 480 m.LockAll() 481 482 select { 483 case <-ch: 484 case <-time.After(1 * time.Second): 485 require.Fail(t, "should be unlocked") 486 } 487 488 m.UnlockAll() 489 }) 490 491 t.Run("UnlockAll releases all locks", func(t *testing.T) { 492 t.Parallel() 493 m := NewShardedRWLocks(5) 494 495 m.LockAll() 496 m.UnlockAll() 497 498 m.Lock(1) 499 m.Unlock(1) 500 501 m.RLock(1) 502 m.RUnlock(1) 503 }) 504 505 t.Run("RLockAll blocks Lock", func(t *testing.T) { 506 t.Parallel() 507 m := NewShardedRWLocks(5) 508 509 m.RLockAll() 510 511 ch := make(chan struct{}) 512 go func() { 513 time.Sleep(50 * time.Millisecond) 514 m.RUnlockAll() 515 516 close(ch) 517 }() 518 519 m.Lock(1) 520 521 select { 522 case <-ch: 523 case <-time.After(1 * time.Second): 524 require.Fail(t, "should be unlocked") 525 } 526 527 m.Unlock(1) 528 }) 529 530 t.Run("RLockAll doesn't block/unblock RLock", func(t *testing.T) { 531 t.Parallel() 532 m := NewShardedRWLocks(5) 533 534 m.RLockAll() 535 m.RLock(1) 536 537 m.RUnlockAll() 538 m.RUnlock(1) 539 }) 540 541 t.Run("RLockAll blocks LockAll", func(t *testing.T) { 542 t.Parallel() 543 m := NewShardedRWLocks(5) 544 545 m.RLockAll() 546 547 ch := make(chan struct{}) 548 go func() { 549 time.Sleep(50 * time.Millisecond) 550 m.RUnlockAll() 551 552 close(ch) 553 }() 554 555 m.LockAll() 556 557 select { 558 case <-ch: 559 case <-time.After(1 * time.Second): 560 require.Fail(t, "should be unlocked") 561 } 562 563 m.UnlockAll() 564 }) 565 566 t.Run("RLockAll doesn't block RLockAll", func(t *testing.T) { 567 t.Parallel() 568 m := NewShardedRWLocks(5) 569 570 m.RLockAll() 571 m.RLockAll() 572 573 m.RUnlockAll() 574 m.RUnlockAll() 575 }) 576 577 t.Run("unlock should wake up next waiting lock", func(t *testing.T) { 578 t.Parallel() 579 m := NewShardedRWLocks(2) 580 581 m.RLock(1) 582 583 ch1 := make(chan struct{}) 584 ch2 := make(chan struct{}) 585 586 go func() { 587 defer close(ch1) 588 589 m.Lock(1) 590 }() 591 592 go func() { 593 defer close(ch2) 594 595 time.Sleep(100 * time.Millisecond) 596 m.Lock(1) 597 }() 598 599 time.Sleep(10 * time.Millisecond) 600 m.RUnlock(1) 601 602 <-ch1 603 604 m.Unlock(1) 605 606 <-ch2 607 608 m.Unlock(1) 609 }) 610 }