github.com/influx6/npkg@v0.8.8/nrates/nrates.go (about)

     1  package nrates
     2  
     3  import (
     4  	"context"
     5  	"strconv"
     6  	"time"
     7  
     8  	"github.com/go-redis/redis/v8"
     9  	"github.com/influx6/npkg/nerror"
    10  	"github.com/influx6/npkg/ntrace"
    11  	openTracing "github.com/opentracing/opentracing-go"
    12  )
    13  
    14  // Rate is the rate of allowed requests. We support
    15  // r/min and r/second.
    16  type Rate int
    17  
    18  const (
    19  	// PerSecond allows us to accept x requests per second
    20  	PerSecond Rate = iota
    21  	// PerMinute allows us to accept x requests per minute
    22  	PerMinute
    23  )
    24  
    25  // HHMMSS formats a timestamp as HH:MM:SS
    26  // Reference: https://yourbasic.org/golang/format-parse-string-time-date-example/
    27  // const HHMMSS = "15:04:05"
    28  
    29  // HHMM formats a timestamp as HH:MM
    30  // Reference: https://yourbasic.org/golang/format-parse-string-time-date-example/
    31  // const HHMM = "15:04"
    32  
    33  type RedisIncr struct {
    34  	Client *redis.Client
    35  }
    36  
    37  func NewRedisIncr(config *redis.Options) (*RedisIncr, error) {
    38  	client := redis.NewClient(config)
    39  
    40  	return &RedisIncr{
    41  		Client: client,
    42  	}, nil
    43  }
    44  
    45  func (b *RedisIncr) Reset(ctx context.Context, r Request) error {
    46  	var span openTracing.Span
    47  	if ctx, span = ntrace.NewMethodSpanFromContext(ctx); span != nil {
    48  		defer span.Finish()
    49  	}
    50  
    51  	status := b.Client.Ping(ctx)
    52  	if err := status.Err(); err != nil {
    53  		return nerror.WrapOnly(err)
    54  	}
    55  
    56  	_, err := b.Client.TxPipelined(ctx, func(pipe redis.Pipeliner) error {
    57  		return pipe.Set(ctx, r.Owner(), 0, 0).Err()
    58  	})
    59  	if err != nil {
    60  		return nerror.WrapOnly(err)
    61  	}
    62  	return nil
    63  }
    64  
    65  func (b *RedisIncr) Count(ctx context.Context, r Request) (int64, error) {
    66  	var span openTracing.Span
    67  	if ctx, span = ntrace.NewMethodSpanFromContext(ctx); span != nil {
    68  		defer span.Finish()
    69  	}
    70  
    71  	status := b.Client.Ping(ctx)
    72  	if err := status.Err(); err != nil {
    73  		return -1, nerror.WrapOnly(err)
    74  	}
    75  
    76  	var res = b.Client.Get(ctx, r.Owner())
    77  	if err := res.Err(); err != nil {
    78  		return -1, nerror.WrapOnly(err)
    79  	}
    80  
    81  	var count, readErr = strconv.Atoi(res.Val())
    82  	if readErr != nil {
    83  		return -1, nerror.WrapOnly(readErr)
    84  	}
    85  
    86  	return int64(count), nil
    87  }
    88  
    89  func (b *RedisIncr) Inc(ctx context.Context, r Request, dur time.Duration) error {
    90  	var span openTracing.Span
    91  	if ctx, span = ntrace.NewMethodSpanFromContext(ctx); span != nil {
    92  		defer span.Finish()
    93  	}
    94  
    95  	status := b.Client.Ping(ctx)
    96  	if err := status.Err(); err != nil {
    97  		return nerror.WrapOnly(err)
    98  	}
    99  
   100  	_, err := b.Client.TxPipelined(ctx, func(pipe redis.Pipeliner) error {
   101  		pipe.Incr(ctx, r.Owner())
   102  		pipe.Expire(ctx, r.Owner(), dur)
   103  		return nil
   104  	})
   105  	if err != nil {
   106  		return nerror.WrapOnly(err)
   107  	}
   108  	return nil
   109  }
   110  
   111  type Request interface {
   112  	Owner() string
   113  	Data() interface{}
   114  }
   115  
   116  type IncrementStore interface {
   117  	Reset(ctx context.Context, r Request) error
   118  	Count(ctx context.Context, r Request) (int64, error)
   119  	Inc(ctx context.Context, r Request, dur time.Duration) (int64, error)
   120  }
   121  
   122  // NewRateLimiter returns a new Limiter.
   123  func NewFactory(db IncrementStore, rate Rate) *LimiterFactory {
   124  	return &LimiterFactory{Store: db, Rate: rate}
   125  }
   126  
   127  type LimiterFactory struct {
   128  	Store IncrementStore
   129  	Rate  Rate
   130  }
   131  
   132  // NewLimiter creates a new Limiter.
   133  func (f LimiterFactory) New(max int64) *RateLimiter {
   134  	return &RateLimiter{
   135  		store: f.Store,
   136  		rate:  f.Rate,
   137  		max:   max,
   138  	}
   139  }
   140  
   141  // NewLimiter creates a new Limiter.
   142  func (f LimiterFactory) NewLimiter(rate Rate, max int64) *RateLimiter {
   143  	return &RateLimiter{
   144  		store: f.Store,
   145  		rate:  rate,
   146  		max:   max,
   147  	}
   148  }
   149  
   150  type RateLimiter struct {
   151  	store IncrementStore
   152  	rate  Rate
   153  	max   int64
   154  }
   155  
   156  // RateLimit applies basic rate limiting to an HTTP request as described
   157  // in Redis' onboarding documentation.
   158  // Reference: https://redislabs.com/redis-best-practices/basic-rate-limiting/
   159  func (l *RateLimiter) RateLimit(ctx context.Context, r Request) error {
   160  	var span openTracing.Span
   161  	if ctx, span = ntrace.NewMethodSpanFromContext(ctx); span != nil {
   162  		defer span.Finish()
   163  	}
   164  
   165  	var expiry time.Duration
   166  
   167  	if l.rate == PerSecond {
   168  		expiry = time.Second
   169  	} else {
   170  		expiry = time.Minute
   171  	}
   172  
   173  	var count, err = l.store.Count(ctx, r)
   174  	if err != nil {
   175  		return nerror.WrapOnly(err)
   176  	}
   177  
   178  	if count >= l.max {
   179  		return nerror.New("requests are throttled, try again later")
   180  	}
   181  
   182  	_, err = l.store.Inc(ctx, r, expiry)
   183  	if err != nil {
   184  		return nerror.WrapOnly(err)
   185  	}
   186  	return nil
   187  }