github.com/puzpuzpuz/xsync/v3@v3.1.1-0.20240225193106-cbe4ec1e954f/mpmcqueue_test.go (about)

     1  // Copyright notice. The following tests are partially based on
     2  // the following file from the Go Programming Language core repo:
     3  // https://github.com/golang/go/blob/831f9376d8d730b16fb33dfd775618dffe13ce7a/src/runtime/chan_test.go
     4  
     5  package xsync_test
     6  
     7  import (
     8  	"runtime"
     9  	"sync"
    10  	"sync/atomic"
    11  	"testing"
    12  	"time"
    13  
    14  	. "github.com/puzpuzpuz/xsync/v3"
    15  )
    16  
    17  func TestQueue_InvalidSize(t *testing.T) {
    18  	defer func() { recover() }()
    19  	NewMPMCQueue(0)
    20  	t.Fatal("no panic detected")
    21  }
    22  
    23  func TestQueueEnqueueDequeue(t *testing.T) {
    24  	q := NewMPMCQueue(10)
    25  	for i := 0; i < 10; i++ {
    26  		q.Enqueue(i)
    27  	}
    28  	for i := 0; i < 10; i++ {
    29  		if got := q.Dequeue(); got != i {
    30  			t.Fatalf("got %v, want %d", got, i)
    31  		}
    32  	}
    33  }
    34  
    35  func TestQueueEnqueueBlocksOnFull(t *testing.T) {
    36  	q := NewMPMCQueue(1)
    37  	q.Enqueue("foo")
    38  	cdone := make(chan bool)
    39  	flag := int32(0)
    40  	go func() {
    41  		q.Enqueue("bar")
    42  		if atomic.LoadInt32(&flag) == 0 {
    43  			t.Error("enqueue on full queue didn't wait for dequeue")
    44  		}
    45  		cdone <- true
    46  	}()
    47  	time.Sleep(50 * time.Millisecond)
    48  	atomic.StoreInt32(&flag, 1)
    49  	if got := q.Dequeue(); got != "foo" {
    50  		t.Fatalf("got %v, want foo", got)
    51  	}
    52  	<-cdone
    53  }
    54  
    55  func TestQueueDequeueBlocksOnEmpty(t *testing.T) {
    56  	q := NewMPMCQueue(2)
    57  	cdone := make(chan bool)
    58  	flag := int32(0)
    59  	go func() {
    60  		q.Dequeue()
    61  		if atomic.LoadInt32(&flag) == 0 {
    62  			t.Error("dequeue on empty queue didn't wait for enqueue")
    63  		}
    64  		cdone <- true
    65  	}()
    66  	time.Sleep(50 * time.Millisecond)
    67  	atomic.StoreInt32(&flag, 1)
    68  	q.Enqueue("foobar")
    69  	<-cdone
    70  }
    71  
    72  func TestQueueTryEnqueueDequeue(t *testing.T) {
    73  	q := NewMPMCQueue(10)
    74  	for i := 0; i < 10; i++ {
    75  		if !q.TryEnqueue(i) {
    76  			t.Fatalf("failed to enqueue for %d", i)
    77  		}
    78  	}
    79  	for i := 0; i < 10; i++ {
    80  		if got, ok := q.TryDequeue(); !ok || got != i {
    81  			t.Fatalf("got %v, want %d, for status %v", got, i, ok)
    82  		}
    83  	}
    84  }
    85  
    86  func TestQueueTryEnqueueOnFull(t *testing.T) {
    87  	q := NewMPMCQueue(1)
    88  	if !q.TryEnqueue("foo") {
    89  		t.Error("failed to enqueue initial item")
    90  	}
    91  	if q.TryEnqueue("bar") {
    92  		t.Error("got success for enqueue on full queue")
    93  	}
    94  }
    95  
    96  func TestQueueTryDequeueBlocksOnEmpty(t *testing.T) {
    97  	q := NewMPMCQueue(2)
    98  	if _, ok := q.TryDequeue(); ok {
    99  		t.Error("got success for enqueue on empty queue")
   100  	}
   101  }
   102  
// hammerQueueBlockingCalls stress-tests the blocking Enqueue/Dequeue
// calls: numThreads producers and numThreads consumers hammer a single
// queue of capacity numThreads under the given GOMAXPROCS setting.
// Correctness is checked by summing every dequeued value and comparing
// against the sum 0+1+...+(numOps-1), which proves each value was
// transferred exactly once. The caller is responsible for restoring
// GOMAXPROCS afterwards.
func hammerQueueBlockingCalls(t *testing.T, gomaxprocs, numOps, numThreads int) {
	runtime.GOMAXPROCS(gomaxprocs)
	q := NewMPMCQueue(numThreads)
	// All goroutines block on startwg so producers and consumers are
	// released at the same moment.
	startwg := sync.WaitGroup{}
	startwg.Add(1)
	csum := make(chan int, numThreads)
	// Start producers.
	for i := 0; i < numThreads; i++ {
		go func(n int) {
			startwg.Wait()
			// Producer n enqueues the strided subsequence
			// n, n+numThreads, n+2*numThreads, ... below numOps.
			for j := n; j < numOps; j += numThreads {
				q.Enqueue(j)
			}
		}(i)
	}
	// Start consumers.
	for i := 0; i < numThreads; i++ {
		go func(n int) {
			startwg.Wait()
			sum := 0
			// Each consumer dequeues exactly numOps/numThreads items;
			// which particular values it sees depends on scheduling.
			for j := n; j < numOps; j += numThreads {
				item := q.Dequeue()
				sum += item.(int)
			}
			csum <- sum
		}(i)
	}
	startwg.Done()
	// Wait for all the sums from consumers.
	sum := 0
	for i := 0; i < numThreads; i++ {
		s := <-csum
		sum += s
	}
	// Assert the total sum: 0+1+...+(numOps-1) = numOps*(numOps-1)/2.
	expectedSum := numOps * (numOps - 1) / 2
	if sum != expectedSum {
		t.Fatalf("sums don't match for %d num ops, %d num threads: got %d, want %d",
			numOps, numThreads, sum, expectedSum)
	}
}
   144  
   145  func TestQueueBlockingCalls(t *testing.T) {
   146  	defer runtime.GOMAXPROCS(runtime.GOMAXPROCS(-1))
   147  	n := 100
   148  	if testing.Short() {
   149  		n = 10
   150  	}
   151  	hammerQueueBlockingCalls(t, 1, 100*n, n)
   152  	hammerQueueBlockingCalls(t, 1, 1000*n, 10*n)
   153  	hammerQueueBlockingCalls(t, 4, 100*n, n)
   154  	hammerQueueBlockingCalls(t, 4, 1000*n, 10*n)
   155  	hammerQueueBlockingCalls(t, 8, 100*n, n)
   156  	hammerQueueBlockingCalls(t, 8, 1000*n, 10*n)
   157  }
   158  
// hammerQueueNonBlockingCalls is the non-blocking counterpart of
// hammerQueueBlockingCalls: producers busy-spin on TryEnqueue and
// consumers busy-spin on TryDequeue until each call succeeds.
// Correctness is checked by summing every dequeued value and comparing
// against 0+1+...+(numOps-1). The caller is responsible for restoring
// GOMAXPROCS afterwards.
func hammerQueueNonBlockingCalls(t *testing.T, gomaxprocs, numOps, numThreads int) {
	runtime.GOMAXPROCS(gomaxprocs)
	q := NewMPMCQueue(numThreads)
	// All goroutines block on startwg so producers and consumers are
	// released at the same moment.
	startwg := sync.WaitGroup{}
	startwg.Add(1)
	csum := make(chan int, numThreads)
	// Start producers.
	for i := 0; i < numThreads; i++ {
		go func(n int) {
			startwg.Wait()
			// Producer n enqueues the strided subsequence
			// n, n+numThreads, n+2*numThreads, ... below numOps.
			for j := n; j < numOps; j += numThreads {
				for !q.TryEnqueue(j) {
					// busy spin until success
				}
			}
		}(i)
	}
	// Start consumers.
	for i := 0; i < numThreads; i++ {
		go func(n int) {
			startwg.Wait()
			sum := 0
			for j := n; j < numOps; j += numThreads {
				var (
					item interface{}
					ok   bool
				)
				for {
					// busy spin until success
					if item, ok = q.TryDequeue(); ok {
						sum += item.(int)
						break
					}
				}
			}
			csum <- sum
		}(i)
	}
	startwg.Done()
	// Wait for all the sums from consumers.
	sum := 0
	for i := 0; i < numThreads; i++ {
		s := <-csum
		sum += s
	}
	// Assert the total sum: 0+1+...+(numOps-1) = numOps*(numOps-1)/2.
	expectedSum := numOps * (numOps - 1) / 2
	if sum != expectedSum {
		t.Fatalf("sums don't match for %d num ops, %d num threads: got %d, want %d",
			numOps, numThreads, sum, expectedSum)
	}
}
   211  
   212  func TestQueueNonBlockingCalls(t *testing.T) {
   213  	defer runtime.GOMAXPROCS(runtime.GOMAXPROCS(-1))
   214  	n := 10
   215  	if testing.Short() {
   216  		n = 1
   217  	}
   218  	hammerQueueNonBlockingCalls(t, 1, n, n)
   219  	hammerQueueNonBlockingCalls(t, 2, 10*n, 2*n)
   220  	hammerQueueNonBlockingCalls(t, 4, 100*n, 4*n)
   221  }
   222  
// benchmarkQueueProdCons measures MPMCQueue throughput with matched
// producer/consumer goroutine pairs (half of GOMAXPROCS pairs, at
// least one). localWork adds CPU busy-work per item to simulate
// processing. Producers enqueue 1s and finish with a 0 sentinel that
// tells one consumer to stop. The `foo == 42` sends are always false
// (foo stays 0); they exist only to keep the busy-work from being
// optimized away and to signal completion on c.
func benchmarkQueueProdCons(b *testing.B, queueSize, localWork int) {
	callsPerSched := queueSize
	procs := runtime.GOMAXPROCS(-1) / 2
	if procs == 0 {
		procs = 1
	}
	// N counts the remaining batches of callsPerSched enqueues,
	// shared by all producers via atomic decrement.
	N := int32(b.N / callsPerSched)
	c := make(chan bool, 2*procs)
	q := NewMPMCQueue(queueSize)
	for p := 0; p < procs; p++ {
		go func() {
			foo := 0
			for atomic.AddInt32(&N, -1) >= 0 {
				for g := 0; g < callsPerSched; g++ {
					for i := 0; i < localWork; i++ {
						foo *= 2
						foo /= 2
					}
					q.Enqueue(1)
				}
			}
			// One sentinel per producer stops exactly one consumer.
			q.Enqueue(0)
			c <- foo == 42
		}()
		go func() {
			foo := 0
			for {
				v := q.Dequeue().(int)
				if v == 0 {
					break
				}
				for i := 0; i < localWork; i++ {
					foo *= 2
					foo /= 2
				}
			}
			c <- foo == 42
		}()
	}
	// Wait for all producers and consumers (two signals per pair).
	for p := 0; p < procs; p++ {
		<-c
		<-c
	}
}
   267  
// BenchmarkQueueProdCons benchmarks the queue with no per-item work.
func BenchmarkQueueProdCons(b *testing.B) {
	benchmarkQueueProdCons(b, 1000, 0)
}
   271  
// BenchmarkQueueProdConsWork100 benchmarks the queue with 100 units of
// simulated CPU work per item.
func BenchmarkQueueProdConsWork100(b *testing.B) {
	benchmarkQueueProdCons(b, 1000, 100)
}
   275  
// benchmarkChanProdCons is the baseline for benchmarkQueueProdCons:
// the identical producer/consumer workload implemented with a buffered
// Go channel instead of an MPMCQueue, so the two benchmarks can be
// compared directly. See benchmarkQueueProdCons for the protocol
// (1s as payload, 0 as per-producer stop sentinel, foo as anti-DCE
// busy-work accumulator).
func benchmarkChanProdCons(b *testing.B, chanSize, localWork int) {
	callsPerSched := chanSize
	procs := runtime.GOMAXPROCS(-1) / 2
	if procs == 0 {
		procs = 1
	}
	N := int32(b.N / callsPerSched)
	c := make(chan bool, 2*procs)
	myc := make(chan int, chanSize)
	for p := 0; p < procs; p++ {
		go func() {
			foo := 0
			for atomic.AddInt32(&N, -1) >= 0 {
				for g := 0; g < callsPerSched; g++ {
					for i := 0; i < localWork; i++ {
						foo *= 2
						foo /= 2
					}
					myc <- 1
				}
			}
			myc <- 0
			c <- foo == 42
		}()
		go func() {
			foo := 0
			for {
				v := <-myc
				if v == 0 {
					break
				}
				for i := 0; i < localWork; i++ {
					foo *= 2
					foo /= 2
				}
			}
			c <- foo == 42
		}()
	}
	// Wait for all producers and consumers (two signals per pair).
	for p := 0; p < procs; p++ {
		<-c
		<-c
	}
}
   320  
// BenchmarkChanProdCons benchmarks the channel baseline with no
// per-item work.
func BenchmarkChanProdCons(b *testing.B) {
	benchmarkChanProdCons(b, 1000, 0)
}
   324  
// BenchmarkChanProdConsWork100 benchmarks the channel baseline with
// 100 units of simulated CPU work per item.
func BenchmarkChanProdConsWork100(b *testing.B) {
	benchmarkChanProdCons(b, 1000, 100)
}