k8s.io/apimachinery@v0.29.2/pkg/util/waitgroup/ratelimited_waitgroup_test.go (about) 1 /* 2 Copyright 2023 The Kubernetes Authors. 3 4 Licensed under the Apache License, Version 2.0 (the "License"); 5 you may not use this file except in compliance with the License. 6 You may obtain a copy of the License at 7 8 http://www.apache.org/licenses/LICENSE-2.0 9 10 Unless required by applicable law or agreed to in writing, software 11 distributed under the License is distributed on an "AS IS" BASIS, 12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 See the License for the specific language governing permissions and 14 limitations under the License. 15 */ 16 17 package waitgroup 18 19 import ( 20 "context" 21 "strings" 22 "sync" 23 "testing" 24 "time" 25 26 "golang.org/x/time/rate" 27 "k8s.io/apimachinery/pkg/util/wait" 28 ) 29 30 func TestRateLimitedSafeWaitGroup(t *testing.T) { 31 // we want to keep track of how many times rate limiter Wait method is 32 // being invoked, both before and after the wait group is in waiting mode. 33 limiter := &limiterWrapper{} 34 35 // we expect the context passed by the factory to be used 36 var cancelInvoked int 37 factory := &factory{ 38 limiter: limiter, 39 grace: 2 * time.Second, 40 ctx: context.Background(), 41 cancel: func() { 42 cancelInvoked++ 43 }, 44 } 45 target := &rateLimitedSafeWaitGroupWrapper{ 46 RateLimitedSafeWaitGroup: &RateLimitedSafeWaitGroup{limiter: limiter}, 47 } 48 49 // two set of requests 50 // - n1: this set will finish using this waitgroup before Wait is invoked 51 // - n2: this set will be in flight after Wait is invoked 52 n1, n2 := 100, 101 53 54 // so we know when all requests in n1 are done using the waitgroup 55 n1DoneWG := sync.WaitGroup{} 56 57 // so we know when all requests in n2 have called Add, 58 // but not finished with the waitgroup yet. 59 // this will allow the test to invoke 'Wait' once all requests 60 // in n2 have called `Add`, but none has called `Done` yet. 61 n2BeforeWaitWG := sync.WaitGroup{} 62 // so we know when all requests in n2 have called Done and 63 // are finished using the waitgroup 64 n2DoneWG := sync.WaitGroup{} 65 66 startCh, blockedCh := make(chan struct{}), make(chan struct{}) 67 n1DoneWG.Add(n1) 68 for i := 0; i < n1; i++ { 69 go func() { 70 defer n1DoneWG.Done() 71 <-startCh 72 73 target.Add(1) 74 // let's finish using the waitgroup immediately 75 target.Done() 76 }() 77 } 78 79 n2BeforeWaitWG.Add(n2) 80 n2DoneWG.Add(n2) 81 for i := 0; i < n2; i++ { 82 go func() { 83 func() { 84 defer n2BeforeWaitWG.Done() 85 <-startCh 86 87 target.Add(1) 88 }() 89 90 func() { 91 defer n2DoneWG.Done() 92 // let's wait for the test to instruct the requests in n2 93 // that it is time to finish using the waitgroup. 94 <-blockedCh 95 96 target.Done() 97 }() 98 }() 99 } 100 101 // initially the count should be zero 102 if count := target.Count(); count != 0 { 103 t.Errorf("expected count to be zero, but got: %d", count) 104 } 105 // start the test 106 close(startCh) 107 // wait for the first set of requests (n1) to be done 108 n1DoneWG.Wait() 109 110 // after the first set of requests (n1) are done, the count should be zero 111 if invoked := limiter.invoked(); invoked != 0 { 112 t.Errorf("expected no call to rate limiter before Wait is called, but got: %d", invoked) 113 } 114 115 // make sure all requetss in the second group (n2) have started using the 116 // waitgroup (Add invoked) but no request is done using the waitgroup yet. 117 n2BeforeWaitWG.Wait() 118 119 // count should be n2, since every request in n2 is still using the waitgroup 120 if count := target.Count(); count != n2 { 121 t.Errorf("expected count to be: %d, but got: %d", n2, count) 122 } 123 124 // time for us to mark the waitgroup as `Waiting` 125 waitDoneCh := make(chan waitResult) 126 go func() { 127 factory.grace = 2 * time.Second 128 before, after, err := target.Wait(factory.NewRateLimiter) 129 waitDoneCh <- waitResult{before: before, after: after, err: err} 130 }() 131 132 // make sure there is no flake in the test due to this race condition 133 var waitingGot bool 134 wait.PollImmediate(500*time.Millisecond, wait.ForeverTestTimeout, func() (done bool, err error) { 135 if waiting := target.Waiting(); waiting { 136 waitingGot = true 137 return true, nil 138 } 139 return false, nil 140 }) 141 // verify that the waitgroup is in 'Waiting' mode 142 if !waitingGot { 143 t.Errorf("expected to be in waiting") 144 } 145 146 // we should not allow any new request to use this waitgroup any longer 147 if err := target.Add(1); err == nil || 148 !strings.Contains(err.Error(), "add with positive delta after Wait is forbidden") { 149 t.Errorf("expected Add to return error while in waiting mode: %v", err) 150 } 151 152 // make sure that RateLimitedSafeWaitGroup passes the right 153 // request count to the limiter factory. 154 if factory.countGot != n2 { 155 t.Errorf("expected count passed to factory to be: %d, but got: %d", n2, factory.countGot) 156 } 157 158 // indicate to all requests (each request in n2) that are 159 // currently using this waitgroup that they can go ahead 160 // and invoke 'Done' to finish using this waitgroup. 161 close(blockedCh) 162 n2DoneWG.Wait() 163 164 if invoked := limiter.invoked(); invoked != n2 { 165 t.Errorf("expected rate limiter to be called %d times, but got: %d", n2, invoked) 166 } 167 168 waitResult := <-waitDoneCh 169 if count := target.Count(); count != 0 { 170 t.Errorf("expected count to be zero, but got: %d", count) 171 } 172 if waitResult.before != n2 { 173 t.Errorf("expected count before Wait to be: %d, but got: %d", n2, waitResult.before) 174 } 175 if waitResult.after != 0 { 176 t.Errorf("expected count after Wait to be zero, but got: %d", waitResult.after) 177 } 178 if cancelInvoked != 1 { 179 t.Errorf("expected context cancel to be invoked once, but got: %d", cancelInvoked) 180 } 181 } 182 183 func TestRateLimitedSafeWaitGroupWithHardTimeout(t *testing.T) { 184 target := &rateLimitedSafeWaitGroupWrapper{ 185 RateLimitedSafeWaitGroup: &RateLimitedSafeWaitGroup{}, 186 } 187 n := 10 188 wg := sync.WaitGroup{} 189 wg.Add(n) 190 for i := 0; i < n; i++ { 191 go func() { 192 defer wg.Done() 193 target.Add(1) 194 }() 195 } 196 197 wg.Wait() 198 if count := target.Count(); count != n { 199 t.Errorf("expected count to be: %d, but got: %d", n, count) 200 } 201 202 ctx, cancel := context.WithCancel(context.Background()) 203 cancel() 204 activeAt, activeNow, err := target.Wait(func(count int) (RateLimiter, context.Context, context.CancelFunc) { 205 return nil, ctx, cancel 206 }) 207 if activeAt != n { 208 t.Errorf("expected active at Wait to be: %d, but got: %d", n, activeAt) 209 } 210 if activeNow != n { 211 t.Errorf("expected active after Wait to be: %d, but got: %d", n, activeNow) 212 } 213 if err != context.Canceled { 214 t.Errorf("expected error: %v, but got: %v", context.Canceled, err) 215 } 216 } 217 218 func TestRateLimitedSafeWaitGroupWithBurstOfOne(t *testing.T) { 219 target := &rateLimitedSafeWaitGroupWrapper{ 220 RateLimitedSafeWaitGroup: &RateLimitedSafeWaitGroup{}, 221 } 222 n := 200 223 grace := 5 * time.Second 224 wg := sync.WaitGroup{} 225 wg.Add(n) 226 for i := 0; i < n; i++ { 227 go func() { 228 defer wg.Done() 229 target.Add(1) 230 }() 231 } 232 wg.Wait() 233 234 waitingCh := make(chan struct{}) 235 wg.Add(n) 236 for i := 0; i < n; i++ { 237 go func() { 238 defer wg.Done() 239 240 <-waitingCh 241 target.Done() 242 }() 243 } 244 defer wg.Wait() 245 246 now := time.Now() 247 t.Logf("Wait starting, N=%d, grace: %s, at: %s", n, grace, now) 248 activeAt, activeNow, err := target.Wait(func(count int) (RateLimiter, context.Context, context.CancelFunc) { 249 defer close(waitingCh) 250 // no deadline in context, Wait will wait forever, we want to measure 251 // how long it takes for the requests to drain. 252 return rate.NewLimiter(rate.Limit(n/int(grace.Seconds())), 1), context.Background(), func() {} 253 }) 254 took := time.Since(now) 255 t.Logf("Wait finished, count(before): %d, count(after): %d, took: %s, err: %v", activeAt, activeNow, took, err) 256 257 // in CPU starved environment, the go routines may not finish in time 258 if took > 2*grace { 259 t.Errorf("expected Wait to take: %s, but it took: %s", grace, took) 260 } 261 } 262 263 type waitResult struct { 264 before, after int 265 err error 266 } 267 268 type rateLimitedSafeWaitGroupWrapper struct { 269 *RateLimitedSafeWaitGroup 270 } 271 272 // used by test only 273 func (wg *rateLimitedSafeWaitGroupWrapper) Count() int { 274 wg.mu.Lock() 275 defer wg.mu.Unlock() 276 277 return wg.count 278 } 279 func (wg *rateLimitedSafeWaitGroupWrapper) Waiting() bool { 280 wg.mu.Lock() 281 defer wg.mu.Unlock() 282 283 return wg.wait 284 } 285 286 type limiterWrapper struct { 287 delegate RateLimiter 288 lock sync.Mutex 289 invokedN int 290 } 291 292 func (w *limiterWrapper) invoked() int { 293 w.lock.Lock() 294 defer w.lock.Unlock() 295 return w.invokedN 296 } 297 func (w *limiterWrapper) Wait(ctx context.Context) error { 298 w.lock.Lock() 299 w.invokedN++ 300 w.lock.Unlock() 301 302 if w.delegate != nil { 303 w.delegate.Wait(ctx) 304 } 305 return nil 306 } 307 308 type factory struct { 309 limiter *limiterWrapper 310 grace time.Duration 311 ctx context.Context 312 cancel context.CancelFunc 313 countGot int 314 } 315 316 func (f *factory) NewRateLimiter(count int) (RateLimiter, context.Context, context.CancelFunc) { 317 f.countGot = count 318 f.limiter.delegate = rate.NewLimiter(rate.Limit(count/int(f.grace.Seconds())), 20) 319 return f.limiter, f.ctx, f.cancel 320 }