github.com/keybase/client/go@v0.0.0-20240309051027-028f7c731f8b/kbfs/kbfssync/semaphore_test.go (about) 1 // Copyright 2017 Keybase Inc. All rights reserved. 2 // Use of this source code is governed by a BSD 3 // license that can be found in the LICENSE file. 4 5 package kbfssync 6 7 import ( 8 "context" 9 "math" 10 "testing" 11 "time" 12 13 "github.com/pkg/errors" 14 "github.com/stretchr/testify/require" 15 ) 16 17 var testTimeout = 10 * time.Second 18 19 type acquireCall struct { 20 n int64 21 count int64 22 err error 23 } 24 25 func callAcquire(ctx context.Context, s *Semaphore, n int64) acquireCall { 26 count, err := s.Acquire(ctx, n) 27 return acquireCall{n, count, err} 28 } 29 30 // requireNoCall checks that there is nothing to read from 31 // callCh. This is a racy check since it doesn't distinguish between 32 // the goroutine with the call not having run yet, and the goroutine 33 // with the call having run but being blocked on the semaphore. 34 func requireNoCall(t *testing.T, callCh <-chan acquireCall) { 35 select { 36 case call := <-callCh: 37 t.Fatalf("Unexpected call: %+v", call) 38 default: 39 } 40 } 41 42 // TestSimple tests that Acquire and Release work in a simple 43 // two-goroutine scenario. 44 func TestSimple(t *testing.T) { 45 ctx, cancel := context.WithTimeout(context.Background(), testTimeout) 46 defer cancel() 47 48 var n int64 = 10 49 50 s := NewSemaphore() 51 require.Equal(t, int64(0), s.Count()) 52 53 callCh := make(chan acquireCall, 1) 54 go func() { 55 callCh <- callAcquire(ctx, s, n) 56 }() 57 58 requireNoCall(t, callCh) 59 60 count := s.Release(n - 1) 61 require.Equal(t, n-1, count) 62 require.Equal(t, n-1, s.Count()) 63 64 requireNoCall(t, callCh) 65 66 count = s.Release(1) 67 require.Equal(t, n, count) 68 69 select { 70 case call := <-callCh: 71 require.Equal(t, acquireCall{n, 0, nil}, call) 72 case <-ctx.Done(): 73 t.Fatal(ctx.Err()) 74 } 75 76 require.Equal(t, int64(0), s.Count()) 77 } 78 79 // TestForceAcquire tests that ForceAcquire works in a simple two-goroutine 80 // scenario. 81 func TestForceAcquire(t *testing.T) { 82 ctx, cancel := context.WithTimeout(context.Background(), testTimeout) 83 defer cancel() 84 85 var n int64 = 10 86 87 s := NewSemaphore() 88 require.Equal(t, int64(0), s.Count()) 89 90 callCh := make(chan acquireCall, 1) 91 go func() { 92 callCh <- callAcquire(ctx, s, n) 93 }() 94 95 requireNoCall(t, callCh) 96 97 count := s.Release(n - 1) 98 require.Equal(t, n-1, count) 99 require.Equal(t, n-1, s.Count()) 100 101 requireNoCall(t, callCh) 102 103 count = s.ForceAcquire(n) 104 require.Equal(t, int64(-1), count) 105 require.Equal(t, int64(-1), s.Count()) 106 107 count = s.Release(n + 1) 108 require.Equal(t, n, count) 109 110 select { 111 case call := <-callCh: 112 require.Equal(t, acquireCall{n, 0, nil}, call) 113 case <-ctx.Done(): 114 t.Fatal(ctx.Err()) 115 } 116 117 require.Equal(t, int64(0), s.Count()) 118 } 119 120 // TestCancel tests that cancelling the context passed into Acquire 121 // causes it to return an error. 122 func TestCancel(t *testing.T) { 123 ctx, cancel := context.WithTimeout(context.Background(), testTimeout) 124 defer cancel() 125 126 ctx2, cancel2 := context.WithCancel(ctx) 127 defer cancel2() 128 129 var n int64 = 10 130 131 s := NewSemaphore() 132 require.Equal(t, int64(0), s.Count()) 133 134 // Do this before spawning the goroutine, so that 135 // callAcquire() will always return a count of n-1. 136 count := s.Release(n - 1) 137 require.Equal(t, n-1, count) 138 require.Equal(t, n-1, s.Count()) 139 140 callCh := make(chan acquireCall, 1) 141 go func() { 142 callCh <- callAcquire(ctx2, s, n) 143 }() 144 145 requireNoCall(t, callCh) 146 147 cancel2() 148 require.Equal(t, n-1, s.Count()) 149 150 select { 151 case call := <-callCh: 152 call.err = errors.Cause(call.err) 153 require.Equal(t, acquireCall{n, n - 1, context.Canceled}, call) 154 case <-ctx.Done(): 155 t.Fatal(ctx.Err()) 156 } 157 158 require.Equal(t, n-1, s.Count()) 159 } 160 161 // TestSerialRelease tests that Release(1) causes exactly one waiting 162 // Acquire(1) to wake up at a time. 163 func TestSerialRelease(t *testing.T) { 164 ctx, cancel := context.WithTimeout(context.Background(), testTimeout) 165 defer cancel() 166 167 acquirerCount := 100 168 169 s := NewSemaphore() 170 acquireCount := 0 171 callCh := make(chan acquireCall, acquirerCount) 172 for i := 0; i < acquirerCount; i++ { 173 go func() { 174 call := callAcquire(ctx, s, 1) 175 acquireCount++ 176 callCh <- call 177 }() 178 } 179 180 for i := 0; i < acquirerCount; i++ { 181 requireNoCall(t, callCh) 182 183 count := s.Release(1) 184 require.Equal(t, int64(1), count) 185 186 select { 187 case call := <-callCh: 188 require.Equal(t, acquireCall{1, 0, nil}, call) 189 case <-ctx.Done(): 190 t.Fatal(ctx.Err()) 191 } 192 193 requireNoCall(t, callCh) 194 195 require.Equal(t, int64(0), s.Count()) 196 } 197 198 // acquireCount should have been incremented race-free. 199 require.Equal(t, acquirerCount, acquireCount) 200 } 201 202 // TestAcquireDifferentSizes tests the scenario where there are 203 // multiple acquirers for different sizes, and we release each size in 204 // increasing order. 205 func TestAcquireDifferentSizes(t *testing.T) { 206 ctx, cancel := context.WithTimeout(context.Background(), testTimeout) 207 defer cancel() 208 209 acquirerCount := 10 210 211 s := NewSemaphore() 212 acquireCount := 0 213 callCh := make(chan acquireCall, acquirerCount) 214 for i := 0; i < acquirerCount; i++ { 215 go func(i int) { 216 call := callAcquire(ctx, s, int64(i+1)) 217 acquireCount++ 218 callCh <- call 219 }(i) 220 } 221 222 for i := 0; i < acquirerCount; i++ { 223 requireNoCall(t, callCh) 224 225 if i == 0 { 226 require.Equal(t, int64(0), s.Count()) 227 } else { 228 count := s.Release(int64(i)) 229 require.Equal(t, int64(i), count) 230 require.Equal(t, int64(i), s.Count()) 231 } 232 233 requireNoCall(t, callCh) 234 235 count := s.Release(1) 236 require.Equal(t, int64(i+1), count) 237 238 select { 239 case call := <-callCh: 240 require.Equal(t, acquireCall{int64(i + 1), 0, nil}, call) 241 case <-ctx.Done(): 242 t.Fatalf("err=%+v, i=%d", ctx.Err(), i) 243 } 244 245 requireNoCall(t, callCh) 246 247 require.Equal(t, int64(0), s.Count()) 248 } 249 250 // acquireCount should have been incremented race-free. 251 require.Equal(t, acquirerCount, acquireCount) 252 } 253 254 func TestAcquirePanic(t *testing.T) { 255 s := NewSemaphore() 256 ctx := context.Background() 257 require.Panics(t, func() { 258 _, _ = s.Acquire(ctx, 0) 259 }) 260 require.Panics(t, func() { 261 _, _ = s.Acquire(ctx, -1) 262 }) 263 } 264 265 func TestForceAcquirePanic(t *testing.T) { 266 s := NewSemaphore() 267 require.Panics(t, func() { 268 s.ForceAcquire(0) 269 }) 270 require.Panics(t, func() { 271 s.ForceAcquire(-1) 272 }) 273 s.ForceAcquire(2) 274 require.Panics(t, func() { 275 s.ForceAcquire(math.MaxInt64) 276 }) 277 } 278 279 func TestReleasePanic(t *testing.T) { 280 s := NewSemaphore() 281 require.Panics(t, func() { 282 s.Release(0) 283 }) 284 require.Panics(t, func() { 285 s.Release(-1) 286 }) 287 s.Release(1) 288 require.Panics(t, func() { 289 s.Release(math.MaxInt64) 290 }) 291 }