github.com/rudderlabs/rudder-go-kit@v0.30.0/throttling/throttling.go (about)

     1  package throttling
     2  
     3  import (
     4  	"context"
     5  	_ "embed"
     6  	"fmt"
     7  	"strconv"
     8  	"strings"
     9  	"time"
    10  
    11  	"github.com/go-redis/redis/v8"
    12  
    13  	"github.com/rudderlabs/rudder-go-kit/stats"
    14  )
    15  
/*
TODOs:
* guard against concurrency? According to benchmarks, Redis performs better with no more than 16 concurrent goroutines.
  * see https://github.com/rudderlabs/redis-throttling-playground/blob/main/Benchmarks.md#best-concurrency-setting-with-sortedset---save-1-1-and---appendonly-yes
*/
    21  
    22  var (
    23  	//go:embed lua/gcra.lua
    24  	gcraLua         string
    25  	gcraRedisScript *redis.Script
    26  	//go:embed lua/sortedset.lua
    27  	sortedSetLua    string
    28  	sortedSetScript *redis.Script
    29  )
    30  
    31  func init() {
    32  	gcraRedisScript = redis.NewScript(gcraLua)
    33  	sortedSetScript = redis.NewScript(sortedSetLua)
    34  }
    35  
    36  type redisSpeaker interface {
    37  	redis.Scripter
    38  	redisSortedSetRemover
    39  }
    40  
    41  type statsCollector interface {
    42  	NewTaggedStat(name, statType string, tags stats.Tags) stats.Measurement
    43  }
    44  
    45  type Limiter struct {
    46  	// for Redis configurations
    47  	// a default redisSpeaker should always be provided for Redis configurations
    48  	redisSpeaker redisSpeaker
    49  
    50  	// for in-memory configurations
    51  	gcra *gcra
    52  
    53  	// other flags
    54  	useGCRA   bool
    55  	gcraBurst int64
    56  
    57  	// metrics
    58  	statsCollector statsCollector
    59  }
    60  
    61  func New(options ...Option) (*Limiter, error) {
    62  	rl := &Limiter{}
    63  	for i := range options {
    64  		options[i](rl)
    65  	}
    66  	if rl.statsCollector == nil {
    67  		rl.statsCollector = stats.Default
    68  	}
    69  	if rl.redisSpeaker != nil {
    70  		return rl, nil
    71  	}
    72  	// Default to in-memory GCRA
    73  	rl.gcra = &gcra{}
    74  	rl.useGCRA = true
    75  	return rl, nil
    76  }
    77  
    78  // Allow returns true if the limit is not exceeded, false otherwise.
    79  func (l *Limiter) Allow(ctx context.Context, cost, rate, window int64, key string) (
    80  	bool, func(context.Context) error, error,
    81  ) {
    82  	allowed, _, tr, err := l.allow(ctx, cost, rate, window, key)
    83  	return allowed, tr, err
    84  }
    85  
    86  // AllowAfter returns true if the limit is not exceeded, false otherwise.
    87  // Additionally, it returns the time.Duration until the next allowed request.
    88  func (l *Limiter) AllowAfter(ctx context.Context, cost, rate, window int64, key string) (
    89  	bool, time.Duration, func(context.Context) error, error,
    90  ) {
    91  	return l.allow(ctx, cost, rate, window, key)
    92  }
    93  
    94  func (l *Limiter) allow(ctx context.Context, cost, rate, window int64, key string) (
    95  	bool, time.Duration, func(context.Context) error, error,
    96  ) {
    97  	if cost < 1 {
    98  		return false, 0, nil, fmt.Errorf("cost must be greater than 0")
    99  	}
   100  	if rate < 1 {
   101  		return false, 0, nil, fmt.Errorf("rate must be greater than 0")
   102  	}
   103  	if window < 1 {
   104  		return false, 0, nil, fmt.Errorf("window must be greater than 0")
   105  	}
   106  	if key == "" {
   107  		return false, 0, nil, fmt.Errorf("key must not be empty")
   108  	}
   109  
   110  	if l.redisSpeaker != nil {
   111  		if l.useGCRA {
   112  			defer l.getTimer(key, "redis-gcra", rate, window)()
   113  			_, allowed, retryAfter, tr, err := l.redisGCRA(ctx, cost, rate, window, key)
   114  			return allowed, retryAfter, tr, err
   115  		}
   116  
   117  		defer l.getTimer(key, "redis-sorted-set", rate, window)()
   118  		_, allowed, retryAfter, tr, err := l.redisSortedSet(ctx, cost, rate, window, key)
   119  		return allowed, retryAfter, tr, err
   120  	}
   121  
   122  	defer l.getTimer(key, "gcra", rate, window)()
   123  	allowed, retryAfter, tr, err := l.gcraLimit(ctx, cost, rate, window, key)
   124  	return allowed, retryAfter, tr, err
   125  }
   126  
   127  func (l *Limiter) redisSortedSet(ctx context.Context, cost, rate, window int64, key string) (
   128  	time.Duration, bool, time.Duration, func(context.Context) error, error,
   129  ) {
   130  	res, err := sortedSetScript.Run(ctx, l.redisSpeaker, []string{key}, cost, rate, window).Result()
   131  	if err != nil {
   132  		return 0, false, 0, nil, fmt.Errorf("could not run SortedSet Redis script: %v", err)
   133  	}
   134  
   135  	result, ok := res.([]interface{})
   136  	if !ok {
   137  		return 0, false, 0, nil, fmt.Errorf("unexpected result from SortedSet Redis script of type %T: %v", res, res)
   138  	}
   139  	if len(result) != 3 {
   140  		return 0, false, 0, nil, fmt.Errorf("unexpected result from SortedSet Redis script of length %d: %+v", len(result), result)
   141  	}
   142  
   143  	t, ok := result[0].(int64)
   144  	if !ok {
   145  		return 0, false, 0, nil, fmt.Errorf("unexpected result[0] from SortedSet Redis script of type %T: %v", result[0], result[0])
   146  	}
   147  	redisTime := time.Duration(t) * time.Microsecond
   148  
   149  	members, ok := result[1].(string)
   150  	if !ok {
   151  		return redisTime, false, 0, nil, fmt.Errorf("unexpected result[1] from SortedSet Redis script of type %T: %v", result[1], result[1])
   152  	}
   153  	if members == "" { // limit exceeded
   154  		retryAfter, ok := result[2].(int64)
   155  		if !ok {
   156  			return redisTime, false, 0, nil, fmt.Errorf("unexpected result[2] from SortedSet Redis script of type %T: %v", result[2], result[2])
   157  		}
   158  		return redisTime, false, time.Duration(retryAfter) * time.Microsecond, nil, nil
   159  	}
   160  
   161  	r := &sortedSetRedisReturn{
   162  		key:     key,
   163  		members: strings.Split(members, ","),
   164  		remover: l.redisSpeaker,
   165  	}
   166  	return redisTime, true, 0, r.Return, nil
   167  }
   168  
   169  func (l *Limiter) redisGCRA(ctx context.Context, cost, rate, window int64, key string) (
   170  	time.Duration, bool, time.Duration, func(context.Context) error, error,
   171  ) {
   172  	burst := rate
   173  	if l.gcraBurst > 0 {
   174  		burst = l.gcraBurst
   175  	}
   176  	res, err := gcraRedisScript.Run(ctx, l.redisSpeaker, []string{key}, burst, rate, window, cost).Result()
   177  	if err != nil {
   178  		return 0, false, 0, nil, fmt.Errorf("could not run GCRA Redis script: %v", err)
   179  	}
   180  
   181  	result, ok := res.([]any)
   182  	if !ok {
   183  		return 0, false, 0, nil, fmt.Errorf("unexpected result from GCRA Redis script of type %T: %v", res, res)
   184  	}
   185  	if len(result) != 5 {
   186  		return 0, false, 0, nil, fmt.Errorf("unexpected result from GCRA Redis scrip of length %d: %+v", len(result), result)
   187  	}
   188  
   189  	t, ok := result[0].(int64)
   190  	if !ok {
   191  		return 0, false, 0, nil, fmt.Errorf("unexpected result[0] from GCRA Redis script of type %T: %v", result[0], result[0])
   192  	}
   193  	redisTime := time.Duration(t) * time.Microsecond
   194  
   195  	allowed, ok := result[1].(int64)
   196  	if !ok {
   197  		return redisTime, false, 0, nil, fmt.Errorf("unexpected result[1] from GCRA Redis script of type %T: %v", result[1], result[1])
   198  	}
   199  	if allowed < 1 { // limit exceeded
   200  		retryAfter, ok := result[3].(int64)
   201  		if !ok {
   202  			return redisTime, false, 0, nil, fmt.Errorf("unexpected result[3] from GCRA Redis script of type %T: %v", result[3], result[3])
   203  		}
   204  		return redisTime, false, time.Duration(retryAfter) * time.Microsecond, nil, nil
   205  	}
   206  
   207  	r := &unsupportedReturn{}
   208  	return redisTime, true, 0, r.Return, nil
   209  }
   210  
   211  func (l *Limiter) gcraLimit(ctx context.Context, cost, rate, window int64, key string) (
   212  	bool, time.Duration, func(context.Context) error, error,
   213  ) {
   214  	burst := rate
   215  	if l.gcraBurst > 0 {
   216  		burst = l.gcraBurst
   217  	}
   218  	allowed, retryAfter, err := l.gcra.limit(ctx, key, cost, burst, rate, window)
   219  	if err != nil {
   220  		return false, 0, nil, fmt.Errorf("could not limit: %w", err)
   221  	}
   222  	if !allowed {
   223  		return false, retryAfter, nil, nil // limit exceeded
   224  	}
   225  	r := &unsupportedReturn{}
   226  	return true, 0, r.Return, nil
   227  }
   228  
   229  func (l *Limiter) getTimer(key, algo string, rate, window int64) func() {
   230  	m := l.statsCollector.NewTaggedStat("throttling", stats.TimerType, stats.Tags{
   231  		"key":    key,
   232  		"algo":   algo,
   233  		"rate":   strconv.FormatInt(rate, 10),
   234  		"window": strconv.FormatInt(window, 10),
   235  	})
   236  	start := time.Now()
   237  	return func() {
   238  		m.Since(start)
   239  	}
   240  }