github.com/zhiqiangxu/util@v0.0.0-20230112053021-0a7aee056cd5/wm/max_test.go (about)

     1  package wm
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"runtime"
     7  	"sync"
     8  	"sync/atomic"
     9  	"testing"
    10  	"time"
    11  
    12  	"golang.org/x/sync/semaphore"
    13  )
    14  
    15  func TestBasicEnter(t *testing.T) {
    16  	m := NewMax(1)
    17  
    18  	m.Enter(context.Background())
    19  
    20  	// Try blocking lock the mutex from a different goroutine. This must
    21  	// not block because the mutex is held.
    22  	ch := make(chan struct{}, 1)
    23  	go func() {
    24  		m.Enter(context.Background())
    25  		ch <- struct{}{}
    26  		m.Exit()
    27  		ch <- struct{}{}
    28  	}()
    29  
    30  	select {
    31  	case <-ch:
    32  		t.Fatalf("Lock succeeded on locked mutex")
    33  	case <-time.After(100 * time.Millisecond):
    34  	}
    35  
    36  	// Unlock the mutex and make sure that the goroutine waiting on Lock()
    37  	// unblocks and succeeds.
    38  	m.Exit()
    39  
    40  	select {
    41  	case <-ch:
    42  	case <-time.After(100 * time.Millisecond):
    43  		t.Fatalf("Lock failed to acquire unlocked mutex")
    44  	}
    45  
    46  	// Make sure we can lock and unlock again.
    47  	m.Enter(context.Background())
    48  	m.Exit()
    49  }
    50  
    51  func TestTryEnter(t *testing.T) {
    52  	m := NewMax(1)
    53  
    54  	// Try to lock. It should succeed.
    55  	if !m.TryEnter() {
    56  		t.Fatalf("TryEnter failed on unlocked mutex")
    57  	}
    58  
    59  	// Try to lock again, it should now fail.
    60  	if m.TryEnter() {
    61  		t.Fatalf("TryEnter succeeded on locked mutex")
    62  	}
    63  
    64  	// Try blocking lock the mutex from a different goroutine. This must
    65  	// not block because the mutex is held.
    66  	ch := make(chan struct{}, 1)
    67  	go func() {
    68  		m.Enter(context.Background())
    69  		ch <- struct{}{}
    70  		m.Exit()
    71  	}()
    72  
    73  	select {
    74  	case <-ch:
    75  		t.Fatalf("Lock succeeded on locked mutex")
    76  	case <-time.After(100 * time.Millisecond):
    77  	}
    78  
    79  	// Unlock the mutex and make sure that the goroutine waiting on Lock()
    80  	// unblocks and succeeds.
    81  	m.Exit()
    82  
    83  	select {
    84  	case <-ch:
    85  	case <-time.After(100 * time.Millisecond):
    86  		t.Fatalf("Lock failed to acquire unlocked mutex")
    87  	}
    88  }
    89  
    90  func TestMutualExclusion(t *testing.T) {
    91  	m := NewMax(1)
    92  
    93  	// Test mutual exclusion by running "gr" goroutines concurrently, and
    94  	// have each one increment a counter "iters" times within the critical
    95  	// section established by the mutex.
    96  	//
    97  	// If at the end the counter is not gr * iters, then we know that
    98  	// goroutines ran concurrently within the critical section.
    99  	//
   100  	// If one of the goroutines doesn't complete, it's likely a bug that
   101  	// causes to it to wait forever.
   102  	const gr = 100
   103  	const iters = 100000
   104  	v := 0
   105  	var wg sync.WaitGroup
   106  	for i := 0; i < gr; i++ {
   107  		wg.Add(1)
   108  		go func() {
   109  			for j := 0; j < iters; j++ {
   110  				m.Enter(context.Background())
   111  				v++
   112  				m.Exit()
   113  			}
   114  			wg.Done()
   115  		}()
   116  	}
   117  
   118  	wg.Wait()
   119  
   120  	if v != gr*iters {
   121  		t.Fatalf("Bad count: got %v, want %v", v, gr*iters)
   122  	}
   123  }
   124  
   125  func TestMutualExclusionWithTryEnter(t *testing.T) {
   126  	m := NewMax(1)
   127  
   128  	// Similar to the previous, with the addition of some goroutines that
   129  	// only increment the count if TryEnter succeeds.
   130  	const gr = 100
   131  	const iters = 100000
   132  	total := int64(gr * iters)
   133  	var tryTotal int64
   134  	v := int64(0)
   135  	var wg sync.WaitGroup
   136  	for i := 0; i < gr; i++ {
   137  		wg.Add(2)
   138  		go func() {
   139  			for j := 0; j < iters; j++ {
   140  				m.Enter(context.Background())
   141  				v++
   142  				m.Exit()
   143  			}
   144  			wg.Done()
   145  		}()
   146  		go func() {
   147  			local := int64(0)
   148  			for j := 0; j < iters; j++ {
   149  				if m.TryEnter() {
   150  					v++
   151  					m.Exit()
   152  					local++
   153  				}
   154  			}
   155  			atomic.AddInt64(&tryTotal, local)
   156  			wg.Done()
   157  		}()
   158  	}
   159  
   160  	wg.Wait()
   161  
   162  	t.Logf("tryTotal = %d", tryTotal)
   163  	total += tryTotal
   164  
   165  	if v != total {
   166  		t.Fatalf("Bad count: got %v, want %v", v, total)
   167  	}
   168  }
   169  
   170  // BenchmarkMax is equivalent to TestMutualExclusion, with the following
   171  // differences:
   172  //
   173  // - The number of goroutines is variable, with the maximum value depending on
   174  // GOMAXPROCS.
   175  //
   176  // - The number of iterations per benchmark is controlled by the benchmarking
   177  // framework.
   178  //
   179  // - Care is taken to ensure that all goroutines participating in the benchmark
   180  // have been created before the benchmark begins.
   181  func BenchmarkMax(b *testing.B) {
   182  	for n, max := 1, 4*runtime.GOMAXPROCS(0); n > 0 && n <= max; n *= 2 {
   183  		b.Run(fmt.Sprintf("%d", n), func(b *testing.B) {
   184  			m := NewMax(1)
   185  
   186  			var ready sync.WaitGroup
   187  			begin := make(chan struct{})
   188  			var end sync.WaitGroup
   189  			for i := 0; i < n; i++ {
   190  				ready.Add(1)
   191  				end.Add(1)
   192  				go func() {
   193  					ready.Done()
   194  					<-begin
   195  					for j := 0; j < b.N; j++ {
   196  						m.Enter(context.Background())
   197  						m.Exit()
   198  					}
   199  					end.Done()
   200  				}()
   201  			}
   202  
   203  			ready.Wait()
   204  			b.ResetTimer()
   205  			close(begin)
   206  			end.Wait()
   207  		})
   208  	}
   209  }
   210  
   211  func BenchmarkSyncMutex(b *testing.B) {
   212  	for n, max := 1, 4*runtime.GOMAXPROCS(0); n > 0 && n <= max; n *= 2 {
   213  		b.Run(fmt.Sprintf("%d", n), func(b *testing.B) {
   214  			var m sync.Mutex
   215  
   216  			var ready sync.WaitGroup
   217  			begin := make(chan struct{})
   218  			var end sync.WaitGroup
   219  			for i := 0; i < n; i++ {
   220  				ready.Add(1)
   221  				end.Add(1)
   222  				go func() {
   223  					ready.Done()
   224  					<-begin
   225  					for j := 0; j < b.N; j++ {
   226  						m.Lock()
   227  						m.Unlock()
   228  					}
   229  					end.Done()
   230  				}()
   231  			}
   232  
   233  			ready.Wait()
   234  			b.ResetTimer()
   235  			close(begin)
   236  			end.Wait()
   237  		})
   238  	}
   239  }
   240  
   241  func BenchmarkChan(b *testing.B) {
   242  	for n, max := 1, 4*runtime.GOMAXPROCS(0); n > 0 && n <= max; n *= 2 {
   243  		b.Run(fmt.Sprintf("%d", n), func(b *testing.B) {
   244  			ch := make(chan struct{}, 1)
   245  
   246  			var ready sync.WaitGroup
   247  			begin := make(chan struct{})
   248  			var end sync.WaitGroup
   249  			for i := 0; i < n; i++ {
   250  				ready.Add(1)
   251  				end.Add(1)
   252  				go func() {
   253  					ready.Done()
   254  					<-begin
   255  					for j := 0; j < b.N; j++ {
   256  						ch <- struct{}{}
   257  						<-ch
   258  					}
   259  					end.Done()
   260  				}()
   261  			}
   262  
   263  			ready.Wait()
   264  			b.ResetTimer()
   265  			close(begin)
   266  			end.Wait()
   267  		})
   268  	}
   269  }
   270  
   271  func BenchmarkSemaphore(b *testing.B) {
   272  	for n, max := 1, 4*runtime.GOMAXPROCS(0); n > 0 && n <= max; n *= 2 {
   273  		b.Run(fmt.Sprintf("%d", n), func(b *testing.B) {
   274  			sema := semaphore.NewWeighted(1)
   275  
   276  			var ready sync.WaitGroup
   277  			begin := make(chan struct{})
   278  			var end sync.WaitGroup
   279  			for i := 0; i < n; i++ {
   280  				ready.Add(1)
   281  				end.Add(1)
   282  				go func() {
   283  					ready.Done()
   284  					<-begin
   285  					for j := 0; j < b.N; j++ {
   286  						sema.Acquire(context.Background(), 1)
   287  						sema.Release(1)
   288  					}
   289  					end.Done()
   290  				}()
   291  			}
   292  
   293  			ready.Wait()
   294  			b.ResetTimer()
   295  			close(begin)
   296  			end.Wait()
   297  		})
   298  	}
   299  }