github.com/grailbio/base@v0.0.11/limiter/batch.go (about)

     1  // Copyright 2021 GRAIL, Inc. All rights reserved.
     2  // Use of this source code is governed by the Apache 2.0
     3  // license that can be found in the LICENSE file.
     4  
     5  package limiter
     6  
     7  import (
     8  	"context"
     9  	"fmt"
    10  	"sync"
    11  	"time"
    12  
    13  	"github.com/grailbio/base/sync/ctxsync"
    14  	"golang.org/x/time/rate"
    15  )
    16  
    17  // BatchLimiter provides the ability to batch calls and apply a rate limit (on the batches).
    18  // Users have to provide an implementation of BatchApi and a rate.Limiter.
    19  // Thereafter callers can concurrently Do calls for each individual ID and the BatchLimiter will
    20  // batch calls (whenever appropriate) while respecting the rate limit.
    21  // Individual requests are serviced in the order of submission.
    22  type BatchLimiter struct {
    23  	api     BatchApi
    24  	limiter *rate.Limiter
    25  	wait    time.Duration
    26  
    27  	mu sync.Mutex
    28  	// pending is the list of pending ids in the order of submission
    29  	pending []ID
    30  	// results maps each submitted ID to its result.
    31  	results map[ID]*Result
    32  }
    33  
    34  // BatchApi needs to be implemented in order to use BatchLimiter.
    35  type BatchApi interface {
    36  	// MaxPerBatch is the max number of ids to call per `Do` (zero implies no limit).
    37  	MaxPerBatch() int
    38  
    39  	// Do the batch call with the given map of IDs to Results.
    40  	// The implementation must call Result.Set to provide the Value or Err (as applicable) for the every ID.
    41  	// At the end of this call, if Result.Set was not called on the result of a particular ID,
    42  	// the corresponding ID's `Do` call will get ErrNoResult.
    43  	Do(map[ID]*Result)
    44  }
    45  
    46  // ID is the identifier of each call.
    47  type ID interface{}
    48  
    49  // Result is the result of an API call for a given id.
    50  type Result struct {
    51  	mu       sync.Mutex
    52  	cond     *ctxsync.Cond
    53  	id       ID
    54  	value    interface{}
    55  	err      error
    56  	done     bool
    57  	nWaiters int
    58  }
    59  
    60  // Set sets the result of a given id with the given value v and error err.
    61  func (r *Result) Set(v interface{}, err error) {
    62  	r.mu.Lock()
    63  	defer r.mu.Unlock()
    64  	r.done = true
    65  	r.value = v
    66  	r.err = err
    67  	r.cond.Broadcast()
    68  }
    69  
    70  func (r *Result) doneC() <-chan struct{} {
    71  	r.mu.Lock()
    72  	return r.cond.Done()
    73  }
    74  
    75  // NewBatchLimiter returns a new BatchLimiter which will call the given batch API
    76  // as per the limits set by the given rate limiter.
    77  func NewBatchLimiter(api BatchApi, limiter *rate.Limiter) *BatchLimiter {
    78  	eventsPerSecond := limiter.Limit()
    79  	if eventsPerSecond == 0 {
    80  		panic("limiter does not allow any events")
    81  	}
    82  	d := float64(time.Second) / float64(eventsPerSecond)
    83  	wait := time.Duration(d)
    84  	return &BatchLimiter{api: api, limiter: limiter, wait: wait, results: make(map[ID]*Result)}
    85  }
    86  
    87  var ErrNoResult = fmt.Errorf("no result")
    88  
    89  // Do submits the given ID to the batch limiter and returns the result or an error.
    90  // If the returned error is ErrNoResult, it indicates that the batch call did not produce any result for the given ID.
    91  // Callers may then apply their own retry strategy if necessary.
    92  // Do merges duplicate calls if the IDs are of a comparable type (and if the result is still pending)
    93  // However, de-duplication is not guaranteed.
    94  // Callers can avoid de-duplication by using a pointer type instead.
    95  func (l *BatchLimiter) Do(ctx context.Context, id ID) (interface{}, error) {
    96  	var t *time.Timer
    97  	defer func() {
    98  		if t != nil {
    99  			t.Stop()
   100  		}
   101  	}()
   102  	r := l.register(id)
   103  	defer l.unregister(r)
   104  	for {
   105  		if done, v, err := l.get(r); done {
   106  			return v, err
   107  		}
   108  		if l.limiter.Allow() {
   109  			m := l.claim()
   110  			if len(m) > 0 {
   111  				l.api.Do(m)
   112  				l.update(m)
   113  				continue
   114  			}
   115  		}
   116  		// Wait half the interval to increase chances of making the next call as early as possible.
   117  		d := l.wait / 2
   118  		if t == nil {
   119  			t = time.NewTimer(d)
   120  		} else {
   121  			t.Reset(d)
   122  		}
   123  		select {
   124  		case <-ctx.Done():
   125  			return nil, ctx.Err()
   126  		case <-r.doneC():
   127  		case <-t.C:
   128  		}
   129  	}
   130  }
   131  
   132  // register registers the given id.
   133  func (l *BatchLimiter) register(id ID) *Result {
   134  	l.mu.Lock()
   135  	defer l.mu.Unlock()
   136  	if _, ok := l.results[id]; !ok {
   137  		l.pending = append(l.pending, id)
   138  		r := &Result{id: id}
   139  		r.cond = ctxsync.NewCond(&r.mu)
   140  		l.results[id] = r
   141  	}
   142  	r := l.results[id]
   143  	r.mu.Lock()
   144  	r.nWaiters += 1
   145  	r.mu.Unlock()
   146  	return r
   147  }
   148  
   149  // unregister indicates that the calling goroutine is no longer interested in the given result.
   150  func (l *BatchLimiter) unregister(r *Result) {
   151  	var remove bool
   152  	r.mu.Lock()
   153  	r.nWaiters -= 1
   154  	remove = r.nWaiters == 0
   155  	r.mu.Unlock()
   156  	if remove {
   157  		l.mu.Lock()
   158  		delete(l.results, r.id)
   159  		l.mu.Unlock()
   160  	}
   161  }
   162  
   163  // get returns whether the result is done and the value and error.
   164  func (l *BatchLimiter) get(r *Result) (bool, interface{}, error) {
   165  	r.mu.Lock()
   166  	defer r.mu.Unlock()
   167  	return r.done, r.value, r.err
   168  }
   169  
   170  // update updates the internal results using the given ones.
   171  // update also sets ErrNoResult as the error result for IDs for which `Result.Set` was not called.
   172  func (l *BatchLimiter) update(results map[ID]*Result) {
   173  	for _, r := range results {
   174  		r.mu.Lock()
   175  		if !r.done {
   176  			r.done, r.err = true, ErrNoResult
   177  		}
   178  		r.mu.Unlock()
   179  	}
   180  }
   181  
   182  // claim claims pending ids and returns a mapping of those ids to their results.
   183  func (l *BatchLimiter) claim() map[ID]*Result {
   184  	l.mu.Lock()
   185  	defer l.mu.Unlock()
   186  	max := l.api.MaxPerBatch()
   187  	if max == 0 {
   188  		max = len(l.pending)
   189  	}
   190  	claimed := make(map[ID]*Result)
   191  	i := 0
   192  	for ; i < len(l.pending) && len(claimed) < max; i++ {
   193  		id := l.pending[i]
   194  		r := l.results[id]
   195  		if r == nil {
   196  			continue
   197  		}
   198  		r.mu.Lock()
   199  		if !r.done {
   200  			claimed[id] = r
   201  		}
   202  		r.mu.Unlock()
   203  	}
   204  	// Remove the claimed ids from the pending list.
   205  	l.pending = l.pending[i:]
   206  	return claimed
   207  }