github.com/cilium/cilium@v1.16.2/pkg/lock/stoppable_waitgroup_test.go (about) 1 // SPDX-License-Identifier: Apache-2.0 2 // Copyright Authors of Cilium 3 4 package lock 5 6 import ( 7 "math/rand/v2" 8 "sync" 9 "sync/atomic" 10 "testing" 11 "time" 12 13 "github.com/stretchr/testify/require" 14 ) 15 16 func TestAdd(t *testing.T) { 17 l := NewStoppableWaitGroup() 18 19 l.Add() 20 require.Equal(t, int64(1), l.i.Load()) 21 l.Add() 22 require.Equal(t, int64(2), l.i.Load()) 23 close(l.noopAdd) 24 l.Add() 25 require.Equal(t, int64(2), l.i.Load()) 26 } 27 28 func TestDone(t *testing.T) { 29 l := NewStoppableWaitGroup() 30 31 l.i.Store(4) 32 l.Done() 33 require.Equal(t, int64(3), l.i.Load()) 34 l.Done() 35 require.Equal(t, int64(2), l.i.Load()) 36 close(l.noopAdd) 37 select { 38 case _, ok := <-l.noopDone: 39 // channel should not have been closed 40 require.True(t, ok) 41 default: 42 } 43 44 l.Done() 45 require.Equal(t, int64(1), l.i.Load()) 46 select { 47 case _, ok := <-l.noopDone: 48 // channel should not have been closed 49 require.True(t, ok) 50 default: 51 } 52 53 l.Done() 54 require.Equal(t, int64(0), l.i.Load()) 55 select { 56 case _, ok := <-l.noopDone: 57 require.False(t, ok) 58 default: 59 // channel should have been closed 60 require.True(t, false) 61 } 62 63 l.Done() 64 require.Equal(t, int64(0), l.i.Load()) 65 } 66 67 func TestStop(t *testing.T) { 68 l := NewStoppableWaitGroup() 69 70 l.Add() 71 require.Equal(t, int64(1), l.i.Load()) 72 l.Add() 73 require.Equal(t, int64(2), l.i.Load()) 74 l.Stop() 75 l.Add() 76 require.Equal(t, int64(2), l.i.Load()) 77 } 78 79 func TestWait(t *testing.T) { 80 l := NewStoppableWaitGroup() 81 82 waitClosed := make(chan struct{}) 83 go func() { 84 l.Wait() 85 close(waitClosed) 86 }() 87 88 l.Add() 89 require.Equal(t, int64(1), l.i.Load()) 90 l.Add() 91 require.Equal(t, int64(2), l.i.Load()) 92 l.Stop() 93 l.Add() 94 require.Equal(t, int64(2), l.i.Load()) 95 96 l.Done() 97 require.Equal(t, int64(1), l.i.Load()) 98 select { 99 case _, ok := <-waitClosed: 100 // channel should not have been closed 101 require.True(t, ok) 102 default: 103 } 104 105 l.Done() 106 require.Equal(t, int64(0), l.i.Load()) 107 select { 108 case _, ok := <-waitClosed: 109 // channel should have been closed 110 require.False(t, ok) 111 default: 112 } 113 114 l.Done() 115 require.Equal(t, int64(0), l.i.Load()) 116 } 117 118 func TestWaitChannel(t *testing.T) { 119 l := NewStoppableWaitGroup() 120 121 l.Add() 122 require.Equal(t, int64(1), l.i.Load()) 123 l.Add() 124 require.Equal(t, int64(2), l.i.Load()) 125 l.Stop() 126 l.Add() 127 require.Equal(t, int64(2), l.i.Load()) 128 129 l.Done() 130 require.Equal(t, int64(1), l.i.Load()) 131 select { 132 case _, ok := <-l.WaitChannel(): 133 // channel should not have been closed 134 require.True(t, ok) 135 default: 136 } 137 138 l.Done() 139 require.Equal(t, int64(0), l.i.Load()) 140 select { 141 case _, ok := <-l.WaitChannel(): 142 // channel should have been closed 143 require.False(t, ok) 144 default: 145 } 146 147 l.Done() 148 require.Equal(t, int64(0), l.i.Load()) 149 } 150 151 func TestParallelism(t *testing.T) { 152 l := NewStoppableWaitGroup() 153 154 in := make(chan int) 155 stop := make(chan struct{}) 156 go func() { 157 for { 158 select { 159 case in <- rand.IntN(1 - 0): 160 case <-stop: 161 close(in) 162 return 163 } 164 } 165 }() 166 var adds atomic.Int64 167 var wg sync.WaitGroup 168 wg.Add(10) 169 for i := 0; i < 10; i++ { 170 go func() { 171 defer wg.Done() 172 for a := range in { 173 if a == 0 { 174 adds.Add(1) 175 l.Add() 176 } else { 177 l.Done() 178 adds.Add(-1) 179 } 180 } 181 }() 182 } 183 184 time.Sleep(time.Duration(rand.IntN(3-0)) * time.Second) 185 close(stop) 186 wg.Wait() 187 for add := adds.Load(); add != 0; add = adds.Load() { 188 switch { 189 case add < 0: 190 adds.Add(1) 191 l.Add() 192 case add > 0: 193 l.Done() 194 adds.Add(-1) 195 } 196 } 197 l.Stop() 198 l.Wait() 199 }