github.com/weaviate/weaviate@v1.24.6/adapters/repos/db/vector/common/sharded_locks_test.go (about)

     1  //                           _       _
     2  // __      _____  __ ___   ___  __ _| |_ ___
     3  // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
     4  //  \ V  V /  __/ (_| |\ V /| | (_| | ||  __/
     5  //   \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
     6  //
     7  //  Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
     8  //
     9  //  CONTACT: hello@weaviate.io
    10  //
    11  
    12  package common
    13  
    14  import (
    15  	"sync"
    16  	"testing"
    17  	"time"
    18  
    19  	"github.com/stretchr/testify/require"
    20  )
    21  
    22  func TestShardedLocks_ParallelLocksAll(t *testing.T) {
    23  	// no asserts
    24  	// ensures parallel LockAll does not fall into deadlock
    25  	count := 10
    26  	sl := NewDefaultShardedLocks()
    27  
    28  	wg := new(sync.WaitGroup)
    29  	wg.Add(count)
    30  	for i := 0; i < count; i++ {
    31  		go func() {
    32  			defer wg.Done()
    33  			sl.LockAll()
    34  			sl.UnlockAll()
    35  		}()
    36  	}
    37  	wg.Wait()
    38  }
    39  
    40  func TestShardedLocks_MixedLocks(t *testing.T) {
    41  	// no asserts
    42  	// ensures parallel LockAll + RLockAll + Lock + RLock does not fall into deadlock
    43  	count := 1000
    44  	sl := NewShardedLocks(10)
    45  
    46  	wg := new(sync.WaitGroup)
    47  	wg.Add(count)
    48  	for i := 0; i < count; i++ {
    49  		go func(i int) {
    50  			defer wg.Done()
    51  			id := uint64(i)
    52  			if i%5 == 0 {
    53  				sl.LockAll()
    54  				sl.UnlockAll()
    55  			} else {
    56  				sl.Lock(id)
    57  				sl.Unlock(id)
    58  			}
    59  		}(i)
    60  	}
    61  	wg.Wait()
    62  }
    63  
    64  func TestShardedLocks(t *testing.T) {
    65  	t.Run("Lock", func(t *testing.T) {
    66  		t.Parallel()
    67  		m := NewShardedLocks(5)
    68  
    69  		m.Lock(1)
    70  
    71  		ch := make(chan struct{})
    72  		go func() {
    73  			time.Sleep(50 * time.Millisecond)
    74  			m.Unlock(1)
    75  
    76  			close(ch)
    77  		}()
    78  
    79  		m.Lock(1)
    80  
    81  		select {
    82  		case <-ch:
    83  		case <-time.After(1 * time.Second):
    84  			require.Fail(t, "should be unlocked")
    85  		}
    86  
    87  		m.Unlock(1)
    88  	})
    89  
    90  	t.Run("Lock blocks LockAll", func(t *testing.T) {
    91  		t.Parallel()
    92  		m := NewShardedLocks(5)
    93  
    94  		m.Lock(1)
    95  
    96  		ch := make(chan struct{})
    97  		go func() {
    98  			time.Sleep(50 * time.Millisecond)
    99  			m.Unlock(1)
   100  
   101  			close(ch)
   102  		}()
   103  
   104  		m.LockAll()
   105  
   106  		select {
   107  		case <-ch:
   108  		case <-time.After(1 * time.Second):
   109  			require.Fail(t, "should be unlocked")
   110  		}
   111  
   112  		m.UnlockAll()
   113  	})
   114  
   115  	t.Run("LockAll blocks Lock", func(t *testing.T) {
   116  		t.Parallel()
   117  		m := NewShardedLocks(5)
   118  
   119  		m.LockAll()
   120  
   121  		ch := make(chan struct{})
   122  		go func() {
   123  			time.Sleep(50 * time.Millisecond)
   124  			m.UnlockAll()
   125  
   126  			close(ch)
   127  		}()
   128  
   129  		m.Lock(1)
   130  
   131  		select {
   132  		case <-ch:
   133  		case <-time.After(1 * time.Second):
   134  			require.Fail(t, "should be unlocked")
   135  		}
   136  
   137  		m.Unlock(1)
   138  	})
   139  
   140  	t.Run("LockAll blocks LockAll", func(t *testing.T) {
   141  		t.Parallel()
   142  		m := NewShardedLocks(5)
   143  
   144  		m.LockAll()
   145  
   146  		ch := make(chan struct{})
   147  		go func() {
   148  			time.Sleep(50 * time.Millisecond)
   149  			m.UnlockAll()
   150  
   151  			close(ch)
   152  		}()
   153  
   154  		m.LockAll()
   155  
   156  		select {
   157  		case <-ch:
   158  		case <-time.After(1 * time.Second):
   159  			require.Fail(t, "should be unlocked")
   160  		}
   161  
   162  		m.UnlockAll()
   163  	})
   164  
   165  	t.Run("UnlockAll releases all locks", func(t *testing.T) {
   166  		t.Parallel()
   167  		m := NewShardedLocks(5)
   168  
   169  		m.LockAll()
   170  		m.UnlockAll()
   171  
   172  		m.Lock(1)
   173  		m.Unlock(1)
   174  	})
   175  
   176  	t.Run("unlock should wake up next waiting lock", func(t *testing.T) {
   177  		t.Parallel()
   178  		m := NewShardedLocks(2)
   179  
   180  		m.Lock(1)
   181  
   182  		ch1 := make(chan struct{})
   183  		ch2 := make(chan struct{})
   184  
   185  		go func() {
   186  			defer close(ch1)
   187  
   188  			m.Lock(1)
   189  		}()
   190  
   191  		go func() {
   192  			defer close(ch2)
   193  
   194  			time.Sleep(100 * time.Millisecond)
   195  			m.Lock(1)
   196  		}()
   197  
   198  		time.Sleep(10 * time.Millisecond)
   199  		m.Unlock(1)
   200  
   201  		<-ch1
   202  
   203  		m.Unlock(1)
   204  
   205  		<-ch2
   206  
   207  		m.Unlock(1)
   208  	})
   209  }
   210  
   211  func TestShardedRWLocks_ParallelLocksAll(t *testing.T) {
   212  	// no asserts
   213  	// ensures parallel LockAll does not fall into deadlock
   214  	count := 10
   215  	sl := NewDefaultShardedRWLocks()
   216  
   217  	wg := new(sync.WaitGroup)
   218  	wg.Add(count)
   219  	for i := 0; i < count; i++ {
   220  		go func() {
   221  			defer wg.Done()
   222  			sl.LockAll()
   223  			sl.UnlockAll()
   224  		}()
   225  	}
   226  	wg.Wait()
   227  }
   228  
   229  func TestShardedRWLocks_ParallelRLocksAll(t *testing.T) {
   230  	// no asserts
   231  	// ensures parallel RLockAll does not fall into deadlock
   232  	count := 10
   233  	sl := NewDefaultShardedRWLocks()
   234  
   235  	wg := new(sync.WaitGroup)
   236  	wg.Add(count)
   237  	for i := 0; i < count; i++ {
   238  		go func() {
   239  			defer wg.Done()
   240  			sl.RLockAll()
   241  			sl.RUnlockAll()
   242  		}()
   243  	}
   244  	wg.Wait()
   245  }
   246  
   247  func TestShardedRWLocks_ParallelLocksAllAndRLocksAll(t *testing.T) {
   248  	// no asserts
   249  	// ensures parallel LockAll + RLockAll does not fall into deadlock
   250  	count := 50
   251  	sl := NewDefaultShardedRWLocks()
   252  
   253  	wg := new(sync.WaitGroup)
   254  	wg.Add(count)
   255  	for i := 0; i < count; i++ {
   256  		go func(i int) {
   257  			defer wg.Done()
   258  			if i%2 == 0 {
   259  				sl.LockAll()
   260  				sl.UnlockAll()
   261  			} else {
   262  				sl.RLockAll()
   263  				sl.RUnlockAll()
   264  			}
   265  		}(i)
   266  	}
   267  	wg.Wait()
   268  }
   269  
   270  func TestShardedRWLocks_MixedLocks(t *testing.T) {
   271  	// no asserts
   272  	// ensures parallel LockAll + RLockAll + Lock + RLock does not fall into deadlock
   273  	count := 1000
   274  	sl := NewShardedRWLocks(10)
   275  
   276  	wg := new(sync.WaitGroup)
   277  	wg.Add(count)
   278  	for i := 0; i < count; i++ {
   279  		go func(i int) {
   280  			defer wg.Done()
   281  			id := uint64(i)
   282  			if i%5 == 0 {
   283  				if i%2 == 0 {
   284  					sl.LockAll()
   285  					sl.UnlockAll()
   286  				} else {
   287  					sl.RLockAll()
   288  					sl.RUnlockAll()
   289  				}
   290  			} else {
   291  				if i%2 == 0 {
   292  					sl.Lock(id)
   293  					sl.Unlock(id)
   294  				} else {
   295  					sl.RLock(id)
   296  					sl.RUnlock(id)
   297  				}
   298  			}
   299  		}(i)
   300  	}
   301  	wg.Wait()
   302  }
   303  
   304  func TestShardedRWLocks(t *testing.T) {
   305  	t.Run("RLock", func(t *testing.T) {
   306  		t.Parallel()
   307  		m := NewShardedRWLocks(5)
   308  
   309  		m.RLock(1)
   310  		m.RLock(1)
   311  
   312  		m.RUnlock(1)
   313  		m.RUnlock(1)
   314  	})
   315  
   316  	t.Run("Lock", func(t *testing.T) {
   317  		t.Parallel()
   318  		m := NewShardedRWLocks(5)
   319  
   320  		m.Lock(1)
   321  
   322  		ch := make(chan struct{})
   323  		go func() {
   324  			time.Sleep(50 * time.Millisecond)
   325  			m.Unlock(1)
   326  
   327  			close(ch)
   328  		}()
   329  
   330  		m.Lock(1)
   331  
   332  		select {
   333  		case <-ch:
   334  		case <-time.After(1 * time.Second):
   335  			require.Fail(t, "should be unlocked")
   336  		}
   337  
   338  		m.Unlock(1)
   339  	})
   340  
   341  	t.Run("RLock blocks Lock", func(t *testing.T) {
   342  		t.Parallel()
   343  		m := NewShardedRWLocks(5)
   344  
   345  		m.RLock(1)
   346  
   347  		ch := make(chan struct{})
   348  		go func() {
   349  			time.Sleep(50 * time.Millisecond)
   350  			m.RUnlock(1)
   351  
   352  			close(ch)
   353  		}()
   354  
   355  		m.Lock(1)
   356  
   357  		select {
   358  		case <-ch:
   359  		case <-time.After(1 * time.Second):
   360  			require.Fail(t, "should be unlocked")
   361  		}
   362  
   363  		m.Unlock(1)
   364  	})
   365  
   366  	t.Run("Lock blocks RLock", func(t *testing.T) {
   367  		t.Parallel()
   368  		m := NewShardedRWLocks(5)
   369  
   370  		m.Lock(1)
   371  
   372  		ch := make(chan struct{})
   373  		go func() {
   374  			time.Sleep(50 * time.Millisecond)
   375  			m.Unlock(1)
   376  
   377  			close(ch)
   378  		}()
   379  
   380  		m.RLock(1)
   381  
   382  		select {
   383  		case <-ch:
   384  		default:
   385  			require.Fail(t, "should be unlocked")
   386  		}
   387  
   388  		m.RUnlock(1)
   389  	})
   390  
   391  	t.Run("Lock blocks LockAll", func(t *testing.T) {
   392  		t.Parallel()
   393  		m := NewShardedRWLocks(5)
   394  
   395  		m.Lock(1)
   396  
   397  		ch := make(chan struct{})
   398  		go func() {
   399  			time.Sleep(50 * time.Millisecond)
   400  			m.Unlock(1)
   401  
   402  			close(ch)
   403  		}()
   404  
   405  		m.LockAll()
   406  
   407  		select {
   408  		case <-ch:
   409  		case <-time.After(1 * time.Second):
   410  			require.Fail(t, "should be unlocked")
   411  		}
   412  
   413  		m.UnlockAll()
   414  	})
   415  
   416  	t.Run("LockAll blocks Lock", func(t *testing.T) {
   417  		t.Parallel()
   418  		m := NewShardedRWLocks(5)
   419  
   420  		m.LockAll()
   421  
   422  		ch := make(chan struct{})
   423  		go func() {
   424  			time.Sleep(50 * time.Millisecond)
   425  			m.UnlockAll()
   426  
   427  			close(ch)
   428  		}()
   429  
   430  		m.Lock(1)
   431  
   432  		select {
   433  		case <-ch:
   434  		case <-time.After(1 * time.Second):
   435  			require.Fail(t, "should be unlocked")
   436  		}
   437  
   438  		m.Unlock(1)
   439  	})
   440  
   441  	t.Run("LockAll blocks RLock", func(t *testing.T) {
   442  		t.Parallel()
   443  		m := NewShardedRWLocks(5)
   444  
   445  		m.LockAll()
   446  
   447  		ch := make(chan struct{})
   448  		go func() {
   449  			time.Sleep(50 * time.Millisecond)
   450  			m.UnlockAll()
   451  
   452  			close(ch)
   453  		}()
   454  
   455  		m.RLock(1)
   456  
   457  		select {
   458  		case <-ch:
   459  		case <-time.After(1 * time.Second):
   460  			require.Fail(t, "should be unlocked")
   461  		}
   462  
   463  		m.RUnlock(1)
   464  	})
   465  
   466  	t.Run("LockAll blocks LockAll", func(t *testing.T) {
   467  		t.Parallel()
   468  		m := NewShardedRWLocks(5)
   469  
   470  		m.LockAll()
   471  
   472  		ch := make(chan struct{})
   473  		go func() {
   474  			time.Sleep(50 * time.Millisecond)
   475  			m.UnlockAll()
   476  
   477  			close(ch)
   478  		}()
   479  
   480  		m.LockAll()
   481  
   482  		select {
   483  		case <-ch:
   484  		case <-time.After(1 * time.Second):
   485  			require.Fail(t, "should be unlocked")
   486  		}
   487  
   488  		m.UnlockAll()
   489  	})
   490  
   491  	t.Run("UnlockAll releases all locks", func(t *testing.T) {
   492  		t.Parallel()
   493  		m := NewShardedRWLocks(5)
   494  
   495  		m.LockAll()
   496  		m.UnlockAll()
   497  
   498  		m.Lock(1)
   499  		m.Unlock(1)
   500  
   501  		m.RLock(1)
   502  		m.RUnlock(1)
   503  	})
   504  
   505  	t.Run("RLockAll blocks Lock", func(t *testing.T) {
   506  		t.Parallel()
   507  		m := NewShardedRWLocks(5)
   508  
   509  		m.RLockAll()
   510  
   511  		ch := make(chan struct{})
   512  		go func() {
   513  			time.Sleep(50 * time.Millisecond)
   514  			m.RUnlockAll()
   515  
   516  			close(ch)
   517  		}()
   518  
   519  		m.Lock(1)
   520  
   521  		select {
   522  		case <-ch:
   523  		case <-time.After(1 * time.Second):
   524  			require.Fail(t, "should be unlocked")
   525  		}
   526  
   527  		m.Unlock(1)
   528  	})
   529  
   530  	t.Run("RLockAll doesn't block/unblock RLock", func(t *testing.T) {
   531  		t.Parallel()
   532  		m := NewShardedRWLocks(5)
   533  
   534  		m.RLockAll()
   535  		m.RLock(1)
   536  
   537  		m.RUnlockAll()
   538  		m.RUnlock(1)
   539  	})
   540  
   541  	t.Run("RLockAll blocks LockAll", func(t *testing.T) {
   542  		t.Parallel()
   543  		m := NewShardedRWLocks(5)
   544  
   545  		m.RLockAll()
   546  
   547  		ch := make(chan struct{})
   548  		go func() {
   549  			time.Sleep(50 * time.Millisecond)
   550  			m.RUnlockAll()
   551  
   552  			close(ch)
   553  		}()
   554  
   555  		m.LockAll()
   556  
   557  		select {
   558  		case <-ch:
   559  		case <-time.After(1 * time.Second):
   560  			require.Fail(t, "should be unlocked")
   561  		}
   562  
   563  		m.UnlockAll()
   564  	})
   565  
   566  	t.Run("RLockAll doesn't block RLockAll", func(t *testing.T) {
   567  		t.Parallel()
   568  		m := NewShardedRWLocks(5)
   569  
   570  		m.RLockAll()
   571  		m.RLockAll()
   572  
   573  		m.RUnlockAll()
   574  		m.RUnlockAll()
   575  	})
   576  
   577  	t.Run("unlock should wake up next waiting lock", func(t *testing.T) {
   578  		t.Parallel()
   579  		m := NewShardedRWLocks(2)
   580  
   581  		m.RLock(1)
   582  
   583  		ch1 := make(chan struct{})
   584  		ch2 := make(chan struct{})
   585  
   586  		go func() {
   587  			defer close(ch1)
   588  
   589  			m.Lock(1)
   590  		}()
   591  
   592  		go func() {
   593  			defer close(ch2)
   594  
   595  			time.Sleep(100 * time.Millisecond)
   596  			m.Lock(1)
   597  		}()
   598  
   599  		time.Sleep(10 * time.Millisecond)
   600  		m.RUnlock(1)
   601  
   602  		<-ch1
   603  
   604  		m.Unlock(1)
   605  
   606  		<-ch2
   607  
   608  		m.Unlock(1)
   609  	})
   610  }