github.com/ydb-platform/ydb-go-sdk/v3@v3.57.0/internal/xsync/mutex_test.go (about)

     1  package xsync
     2  
     3  import (
     4  	"runtime"
     5  	"sync"
     6  	"sync/atomic"
     7  	"testing"
     8  
     9  	"github.com/stretchr/testify/require"
    10  
    11  	"github.com/ydb-platform/ydb-go-sdk/v3/internal/xtest"
    12  )
    13  
    14  func TestMutex(t *testing.T) {
    15  	xtest.TestManyTimes(t, func(t testing.TB) {
    16  		var m Mutex
    17  		a, b := 1, 1
    18  
    19  		var wg sync.WaitGroup
    20  		f := func() {
    21  			defer wg.Done()
    22  
    23  			if a+b == 2 {
    24  				a = 2
    25  			} else {
    26  				b = 2
    27  			}
    28  		}
    29  
    30  		wg.Add(2)
    31  		go m.WithLock(f)
    32  		go m.WithLock(f)
    33  
    34  		wg.Wait()
    35  		require.Equal(t, 2, a)
    36  		require.Equal(t, 2, b)
    37  	})
    38  }
    39  
    40  func TestRWMutex(t *testing.T) {
    41  	xtest.TestManyTimesWithName(t, "WithLock", func(t testing.TB) {
    42  		var m Mutex
    43  		a, b := 1, 1
    44  
    45  		var wg sync.WaitGroup
    46  		f := func() {
    47  			defer wg.Done()
    48  
    49  			if a+b == 2 {
    50  				a = 2
    51  			} else {
    52  				b = 2
    53  			}
    54  		}
    55  
    56  		wg.Add(2)
    57  		go m.WithLock(f)
    58  		go m.WithLock(f)
    59  
    60  		wg.Wait()
    61  		require.Equal(t, 2, a)
    62  		require.Equal(t, 2, b)
    63  	})
    64  	xtest.TestManyTimesWithName(t, "WithRLock", func(t testing.TB) {
    65  		var m RWMutex
    66  		a, b := 1, 1
    67  
    68  		var badSummCount int64
    69  		var wg sync.WaitGroup
    70  
    71  		for reader := 0; reader < 100; reader++ {
    72  			wg.Add(1)
    73  			go func() {
    74  				defer wg.Done()
    75  
    76  				for i := 0; i < 1000; i++ {
    77  					m.WithRLock(func() {
    78  						if a+b != 2 {
    79  							atomic.AddInt64(&badSummCount, 1)
    80  						}
    81  					})
    82  					runtime.Gosched()
    83  				}
    84  			}()
    85  		}
    86  
    87  		wg.Add(1)
    88  		go func() {
    89  			defer wg.Done()
    90  
    91  			for i := 0; i < 100; i++ {
    92  				m.WithLock(func() {
    93  					a++
    94  					b--
    95  				})
    96  				runtime.Gosched()
    97  			}
    98  		}()
    99  
   100  		wg.Wait()
   101  		require.Equal(t, 2, a+b)
   102  		require.Equal(t, int64(0), badSummCount)
   103  	})
   104  }