github.com/cilium/cilium@v1.16.2/pkg/lock/stoppable_waitgroup_test.go (about)

     1  // SPDX-License-Identifier: Apache-2.0
     2  // Copyright Authors of Cilium
     3  
     4  package lock
     5  
     6  import (
     7  	"math/rand/v2"
     8  	"sync"
     9  	"sync/atomic"
    10  	"testing"
    11  	"time"
    12  
    13  	"github.com/stretchr/testify/require"
    14  )
    15  
    16  func TestAdd(t *testing.T) {
    17  	l := NewStoppableWaitGroup()
    18  
    19  	l.Add()
    20  	require.Equal(t, int64(1), l.i.Load())
    21  	l.Add()
    22  	require.Equal(t, int64(2), l.i.Load())
    23  	close(l.noopAdd)
    24  	l.Add()
    25  	require.Equal(t, int64(2), l.i.Load())
    26  }
    27  
    28  func TestDone(t *testing.T) {
    29  	l := NewStoppableWaitGroup()
    30  
    31  	l.i.Store(4)
    32  	l.Done()
    33  	require.Equal(t, int64(3), l.i.Load())
    34  	l.Done()
    35  	require.Equal(t, int64(2), l.i.Load())
    36  	close(l.noopAdd)
    37  	select {
    38  	case _, ok := <-l.noopDone:
    39  		// channel should not have been closed
    40  		require.True(t, ok)
    41  	default:
    42  	}
    43  
    44  	l.Done()
    45  	require.Equal(t, int64(1), l.i.Load())
    46  	select {
    47  	case _, ok := <-l.noopDone:
    48  		// channel should not have been closed
    49  		require.True(t, ok)
    50  	default:
    51  	}
    52  
    53  	l.Done()
    54  	require.Equal(t, int64(0), l.i.Load())
    55  	select {
    56  	case _, ok := <-l.noopDone:
    57  		require.False(t, ok)
    58  	default:
    59  		// channel should have been closed
    60  		require.True(t, false)
    61  	}
    62  
    63  	l.Done()
    64  	require.Equal(t, int64(0), l.i.Load())
    65  }
    66  
    67  func TestStop(t *testing.T) {
    68  	l := NewStoppableWaitGroup()
    69  
    70  	l.Add()
    71  	require.Equal(t, int64(1), l.i.Load())
    72  	l.Add()
    73  	require.Equal(t, int64(2), l.i.Load())
    74  	l.Stop()
    75  	l.Add()
    76  	require.Equal(t, int64(2), l.i.Load())
    77  }
    78  
    79  func TestWait(t *testing.T) {
    80  	l := NewStoppableWaitGroup()
    81  
    82  	waitClosed := make(chan struct{})
    83  	go func() {
    84  		l.Wait()
    85  		close(waitClosed)
    86  	}()
    87  
    88  	l.Add()
    89  	require.Equal(t, int64(1), l.i.Load())
    90  	l.Add()
    91  	require.Equal(t, int64(2), l.i.Load())
    92  	l.Stop()
    93  	l.Add()
    94  	require.Equal(t, int64(2), l.i.Load())
    95  
    96  	l.Done()
    97  	require.Equal(t, int64(1), l.i.Load())
    98  	select {
    99  	case _, ok := <-waitClosed:
   100  		// channel should not have been closed
   101  		require.True(t, ok)
   102  	default:
   103  	}
   104  
   105  	l.Done()
   106  	require.Equal(t, int64(0), l.i.Load())
   107  	select {
   108  	case _, ok := <-waitClosed:
   109  		// channel should have been closed
   110  		require.False(t, ok)
   111  	default:
   112  	}
   113  
   114  	l.Done()
   115  	require.Equal(t, int64(0), l.i.Load())
   116  }
   117  
   118  func TestWaitChannel(t *testing.T) {
   119  	l := NewStoppableWaitGroup()
   120  
   121  	l.Add()
   122  	require.Equal(t, int64(1), l.i.Load())
   123  	l.Add()
   124  	require.Equal(t, int64(2), l.i.Load())
   125  	l.Stop()
   126  	l.Add()
   127  	require.Equal(t, int64(2), l.i.Load())
   128  
   129  	l.Done()
   130  	require.Equal(t, int64(1), l.i.Load())
   131  	select {
   132  	case _, ok := <-l.WaitChannel():
   133  		// channel should not have been closed
   134  		require.True(t, ok)
   135  	default:
   136  	}
   137  
   138  	l.Done()
   139  	require.Equal(t, int64(0), l.i.Load())
   140  	select {
   141  	case _, ok := <-l.WaitChannel():
   142  		// channel should have been closed
   143  		require.False(t, ok)
   144  	default:
   145  	}
   146  
   147  	l.Done()
   148  	require.Equal(t, int64(0), l.i.Load())
   149  }
   150  
   151  func TestParallelism(t *testing.T) {
   152  	l := NewStoppableWaitGroup()
   153  
   154  	in := make(chan int)
   155  	stop := make(chan struct{})
   156  	go func() {
   157  		for {
   158  			select {
   159  			case in <- rand.IntN(1 - 0):
   160  			case <-stop:
   161  				close(in)
   162  				return
   163  			}
   164  		}
   165  	}()
   166  	var adds atomic.Int64
   167  	var wg sync.WaitGroup
   168  	wg.Add(10)
   169  	for i := 0; i < 10; i++ {
   170  		go func() {
   171  			defer wg.Done()
   172  			for a := range in {
   173  				if a == 0 {
   174  					adds.Add(1)
   175  					l.Add()
   176  				} else {
   177  					l.Done()
   178  					adds.Add(-1)
   179  				}
   180  			}
   181  		}()
   182  	}
   183  
   184  	time.Sleep(time.Duration(rand.IntN(3-0)) * time.Second)
   185  	close(stop)
   186  	wg.Wait()
   187  	for add := adds.Load(); add != 0; add = adds.Load() {
   188  		switch {
   189  		case add < 0:
   190  			adds.Add(1)
   191  			l.Add()
   192  		case add > 0:
   193  			l.Done()
   194  			adds.Add(-1)
   195  		}
   196  	}
   197  	l.Stop()
   198  	l.Wait()
   199  }