github.com/grailbio/base@v0.0.11/limiter/batch.go (about) 1 // Copyright 2021 GRAIL, Inc. All rights reserved. 2 // Use of this source code is governed by the Apache 2.0 3 // license that can be found in the LICENSE file. 4 5 package limiter 6 7 import ( 8 "context" 9 "fmt" 10 "sync" 11 "time" 12 13 "github.com/grailbio/base/sync/ctxsync" 14 "golang.org/x/time/rate" 15 ) 16 17 // BatchLimiter provides the ability to batch calls and apply a rate limit (on the batches). 18 // Users have to provide an implementation of BatchApi and a rate.Limiter. 19 // Thereafter callers can concurrently Do calls for each individual ID and the BatchLimiter will 20 // batch calls (whenever appropriate) while respecting the rate limit. 21 // Individual requests are serviced in the order of submission. 22 type BatchLimiter struct { 23 api BatchApi 24 limiter *rate.Limiter 25 wait time.Duration 26 27 mu sync.Mutex 28 // pending is the list of pending ids in the order of submission 29 pending []ID 30 // results maps each submitted ID to its result. 31 results map[ID]*Result 32 } 33 34 // BatchApi needs to be implemented in order to use BatchLimiter. 35 type BatchApi interface { 36 // MaxPerBatch is the max number of ids to call per `Do` (zero implies no limit). 37 MaxPerBatch() int 38 39 // Do the batch call with the given map of IDs to Results. 40 // The implementation must call Result.Set to provide the Value or Err (as applicable) for the every ID. 41 // At the end of this call, if Result.Set was not called on the result of a particular ID, 42 // the corresponding ID's `Do` call will get ErrNoResult. 43 Do(map[ID]*Result) 44 } 45 46 // ID is the identifier of each call. 47 type ID interface{} 48 49 // Result is the result of an API call for a given id. 50 type Result struct { 51 mu sync.Mutex 52 cond *ctxsync.Cond 53 id ID 54 value interface{} 55 err error 56 done bool 57 nWaiters int 58 } 59 60 // Set sets the result of a given id with the given value v and error err. 61 func (r *Result) Set(v interface{}, err error) { 62 r.mu.Lock() 63 defer r.mu.Unlock() 64 r.done = true 65 r.value = v 66 r.err = err 67 r.cond.Broadcast() 68 } 69 70 func (r *Result) doneC() <-chan struct{} { 71 r.mu.Lock() 72 return r.cond.Done() 73 } 74 75 // NewBatchLimiter returns a new BatchLimiter which will call the given batch API 76 // as per the limits set by the given rate limiter. 77 func NewBatchLimiter(api BatchApi, limiter *rate.Limiter) *BatchLimiter { 78 eventsPerSecond := limiter.Limit() 79 if eventsPerSecond == 0 { 80 panic("limiter does not allow any events") 81 } 82 d := float64(time.Second) / float64(eventsPerSecond) 83 wait := time.Duration(d) 84 return &BatchLimiter{api: api, limiter: limiter, wait: wait, results: make(map[ID]*Result)} 85 } 86 87 var ErrNoResult = fmt.Errorf("no result") 88 89 // Do submits the given ID to the batch limiter and returns the result or an error. 90 // If the returned error is ErrNoResult, it indicates that the batch call did not produce any result for the given ID. 91 // Callers may then apply their own retry strategy if necessary. 92 // Do merges duplicate calls if the IDs are of a comparable type (and if the result is still pending) 93 // However, de-duplication is not guaranteed. 94 // Callers can avoid de-duplication by using a pointer type instead. 95 func (l *BatchLimiter) Do(ctx context.Context, id ID) (interface{}, error) { 96 var t *time.Timer 97 defer func() { 98 if t != nil { 99 t.Stop() 100 } 101 }() 102 r := l.register(id) 103 defer l.unregister(r) 104 for { 105 if done, v, err := l.get(r); done { 106 return v, err 107 } 108 if l.limiter.Allow() { 109 m := l.claim() 110 if len(m) > 0 { 111 l.api.Do(m) 112 l.update(m) 113 continue 114 } 115 } 116 // Wait half the interval to increase chances of making the next call as early as possible. 117 d := l.wait / 2 118 if t == nil { 119 t = time.NewTimer(d) 120 } else { 121 t.Reset(d) 122 } 123 select { 124 case <-ctx.Done(): 125 return nil, ctx.Err() 126 case <-r.doneC(): 127 case <-t.C: 128 } 129 } 130 } 131 132 // register registers the given id. 133 func (l *BatchLimiter) register(id ID) *Result { 134 l.mu.Lock() 135 defer l.mu.Unlock() 136 if _, ok := l.results[id]; !ok { 137 l.pending = append(l.pending, id) 138 r := &Result{id: id} 139 r.cond = ctxsync.NewCond(&r.mu) 140 l.results[id] = r 141 } 142 r := l.results[id] 143 r.mu.Lock() 144 r.nWaiters += 1 145 r.mu.Unlock() 146 return r 147 } 148 149 // unregister indicates that the calling goroutine is no longer interested in the given result. 150 func (l *BatchLimiter) unregister(r *Result) { 151 var remove bool 152 r.mu.Lock() 153 r.nWaiters -= 1 154 remove = r.nWaiters == 0 155 r.mu.Unlock() 156 if remove { 157 l.mu.Lock() 158 delete(l.results, r.id) 159 l.mu.Unlock() 160 } 161 } 162 163 // get returns whether the result is done and the value and error. 164 func (l *BatchLimiter) get(r *Result) (bool, interface{}, error) { 165 r.mu.Lock() 166 defer r.mu.Unlock() 167 return r.done, r.value, r.err 168 } 169 170 // update updates the internal results using the given ones. 171 // update also sets ErrNoResult as the error result for IDs for which `Result.Set` was not called. 172 func (l *BatchLimiter) update(results map[ID]*Result) { 173 for _, r := range results { 174 r.mu.Lock() 175 if !r.done { 176 r.done, r.err = true, ErrNoResult 177 } 178 r.mu.Unlock() 179 } 180 } 181 182 // claim claims pending ids and returns a mapping of those ids to their results. 183 func (l *BatchLimiter) claim() map[ID]*Result { 184 l.mu.Lock() 185 defer l.mu.Unlock() 186 max := l.api.MaxPerBatch() 187 if max == 0 { 188 max = len(l.pending) 189 } 190 claimed := make(map[ID]*Result) 191 i := 0 192 for ; i < len(l.pending) && len(claimed) < max; i++ { 193 id := l.pending[i] 194 r := l.results[id] 195 if r == nil { 196 continue 197 } 198 r.mu.Lock() 199 if !r.done { 200 claimed[id] = r 201 } 202 r.mu.Unlock() 203 } 204 // Remove the claimed ids from the pending list. 205 l.pending = l.pending[i:] 206 return claimed 207 }