github.com/Schaudge/grailbase@v0.0.0-20240223061707-44c758a471c0/limiter/batch_test.go (about) 1 // Copyright 2021 GRAIL, Inc. All rights reserved. 2 // Use of this source code is governed by the Apache 2.0 3 // license that can be found in the LICENSE file. 4 5 package limiter 6 7 import ( 8 "context" 9 "fmt" 10 "strconv" 11 "sync" 12 "testing" 13 "time" 14 15 "github.com/Schaudge/grailbase/traverse" 16 "golang.org/x/time/rate" 17 ) 18 19 type testBatchApi struct { 20 mu sync.Mutex 21 usePtr bool 22 maxPerBatch int 23 last time.Time 24 perBatchIds [][]string 25 durs []time.Duration 26 idSeenCount map[string]int 27 } 28 29 func (a *testBatchApi) MaxPerBatch() int { return a.maxPerBatch } 30 func (a *testBatchApi) Do(results map[ID]*Result) { 31 a.mu.Lock() 32 defer a.mu.Unlock() 33 now := time.Now() 34 if a.last.IsZero() { 35 a.last = now 36 } 37 ids := make([]string, 0, len(results)) 38 for k, r := range results { 39 var id string 40 if a.usePtr { 41 id = *k.(*string) 42 } else { 43 id = k.(string) 44 } 45 ids = append(ids, id) 46 idSeenCount := a.idSeenCount[id] 47 i, err := strconv.Atoi(id) 48 if err != nil { 49 i = -1 50 } 51 switch { 52 case shouldErr(i): 53 case i%2 == 0: 54 r.Set(nil, fmt.Errorf("failed_%s_count_%d", id, idSeenCount)) 55 default: 56 r.Set(fmt.Sprintf("value-%s", id), nil) 57 } 58 a.idSeenCount[id] = idSeenCount + 1 59 } 60 a.perBatchIds = append(a.perBatchIds, ids) 61 a.durs = append(a.durs, now.Sub(a.last)) 62 a.last = now 63 return 64 } 65 66 func TestSimple(t *testing.T) { 67 a := &testBatchApi{idSeenCount: make(map[string]int)} 68 l := NewBatchLimiter(a, rate.NewLimiter(rate.Every(time.Millisecond), 1)) 69 id := "test" 70 _, _ = l.Do(context.Background(), id) 71 if got, want := a.idSeenCount[id], 1; got != want { 72 t.Errorf("got %d, want %d", got, want) 73 } 74 _, _ = l.Do(context.Background(), id) 75 if got, want := a.idSeenCount[id], 2; got != want { 76 t.Errorf("got %d, want %d", got, want) 77 } 78 } 79 80 func TestCtxCanceled(t *testing.T) { 81 a := &testBatchApi{idSeenCount: make(map[string]int)} 82 l := NewBatchLimiter(a, rate.NewLimiter(rate.Every(time.Second), 1)) 83 id1, id2 := "test1", "test2" 84 _, _ = l.Do(context.Background(), id1) 85 if got, want := a.idSeenCount[id1], 1; got != want { 86 t.Errorf("got %d, want %d", got, want) 87 } 88 ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond) 89 defer cancel() 90 var wg sync.WaitGroup 91 wg.Add(1) 92 go func() { 93 defer wg.Done() 94 _, _ = l.Do(ctx, id1) 95 }() 96 wg.Add(1) 97 go func() { 98 defer wg.Done() 99 _, _ = l.Do(context.Background(), id2) 100 }() 101 wg.Wait() 102 if got, want := a.idSeenCount[id1], 1; got != want { 103 t.Errorf("got %d, want %d", got, want) 104 } 105 if got, want := a.idSeenCount[id2], 1; got != want { 106 t.Errorf("got %d, want %d", got, want) 107 } 108 } 109 110 func TestSometimesDedup(t *testing.T) { 111 const num = 5 112 a := &testBatchApi{idSeenCount: make(map[string]int)} 113 l := NewBatchLimiter(a, rate.NewLimiter(rate.Every(10*time.Millisecond), num)) 114 id := "test" 115 a.mu.Lock() // Locks the batch API. 116 var done sync.WaitGroup 117 done.Add(num) 118 for i := 0; i < num; i++ { 119 go func() { 120 defer done.Done() 121 _, _ = l.Do(context.Background(), id) 122 }() 123 } 124 var allWaiting bool 125 for !allWaiting { 126 l.mu.Lock() 127 r := l.results[id] 128 l.mu.Unlock() 129 if r == nil { 130 time.Sleep(time.Millisecond) 131 continue 132 } 133 r.mu.Lock() 134 allWaiting = r.nWaiters == num 135 r.mu.Unlock() 136 } 137 a.mu.Unlock() // Unlock the batch API. 138 done.Wait() // Wait for all the goroutines on the same ID to complete 139 if got, want := a.idSeenCount[id], 1; got != want { 140 t.Errorf("got %d, want %d", got, want) 141 } 142 } 143 144 func TestNoDedup(t *testing.T) { 145 a := &testBatchApi{usePtr: true, idSeenCount: make(map[string]int)} 146 l := NewBatchLimiter(a, rate.NewLimiter(rate.Every(10*time.Millisecond), 1)) 147 id := "test" 148 a.mu.Lock() // Locks the batch API. 149 var started, done sync.WaitGroup 150 started.Add(5) 151 done.Add(5) 152 for i := 0; i < 5; i++ { 153 go func() { 154 started.Done() 155 id := id 156 _, _ = l.Do(context.Background(), &id) 157 done.Done() 158 }() 159 } 160 started.Wait() // Wait for all the goroutines on the same ID to start 161 a.mu.Unlock() // Unlock the batch API. 162 done.Wait() // Wait for all the goroutines on the same ID to complete 163 if got, want := a.idSeenCount[id], 5; got != want { 164 t.Errorf("got %d, want %d", got, want) 165 } 166 } 167 168 func TestDo(t *testing.T) { 169 testApi(t, &testBatchApi{idSeenCount: make(map[string]int)}, time.Second) 170 } 171 172 func TestDoWithMax5(t *testing.T) { 173 testApi(t, &testBatchApi{maxPerBatch: 5, idSeenCount: make(map[string]int)}, 3*time.Second) 174 } 175 176 func TestDoWithMax8(t *testing.T) { 177 testApi(t, &testBatchApi{maxPerBatch: 8, idSeenCount: make(map[string]int)}, 2*time.Second) 178 } 179 180 type result struct { 181 v string 182 err error 183 } 184 185 func shouldErr(i int) bool { 186 return i%5 == 0 && i%2 != 0 187 } 188 189 func testApi(t *testing.T, a *testBatchApi, timeout time.Duration) { 190 const numIds = 100 191 var interval = 100 * time.Millisecond 192 l := NewBatchLimiter(a, rate.NewLimiter(rate.Every(interval), 1)) 193 var mu sync.Mutex 194 results := make(map[string]result) 195 _ = traverse.Each(numIds, func(i int) error { 196 time.Sleep(time.Duration(i*10) * time.Millisecond) 197 id := fmt.Sprintf("%d", i) 198 ctx, cancel := context.WithTimeout(context.Background(), timeout) 199 defer cancel() 200 v, err := l.Do(ctx, id) 201 mu.Lock() 202 r := result{err: err} 203 if r.err == nil { 204 r.v = v.(string) 205 } 206 results[id] = r 207 mu.Unlock() 208 return nil 209 }) 210 for i := 0; i < numIds; i++ { 211 id := fmt.Sprintf("%d", i) 212 if got, want := a.idSeenCount[id], 1; got != want { 213 t.Errorf("[%v] got %d, want %d", id, got, want) 214 } 215 if shouldErr(i) { 216 if got, want := results[id].err, ErrNoResult; got != want { 217 t.Errorf("[%d] got %v, want %v", i, got, want) 218 } 219 } 220 } 221 for _, dur := range a.durs[1:] { 222 if got, want, diff := dur, interval, (dur - interval).Round(5*time.Millisecond); diff < 0 { 223 t.Errorf("got %v, want %v, diff %v", got, want, diff) 224 } 225 } 226 for i, batchIds := range a.perBatchIds { 227 if want := a.maxPerBatch; want > 0 { 228 if got := len(batchIds); got > want { 229 t.Errorf("got %v, want <=%v", got, want) 230 } 231 } 232 t.Logf("batch %d (after %s): %v", i, a.durs[i].Round(time.Millisecond), batchIds) 233 } 234 }