github.com/influx6/npkg@v0.8.8/nrates/nrates.go (about) 1 package nrates 2 3 import ( 4 "context" 5 "strconv" 6 "time" 7 8 "github.com/go-redis/redis/v8" 9 "github.com/influx6/npkg/nerror" 10 "github.com/influx6/npkg/ntrace" 11 openTracing "github.com/opentracing/opentracing-go" 12 ) 13 14 // Rate is the rate of allowed requests. We support 15 // r/min and r/second. 16 type Rate int 17 18 const ( 19 // PerSecond allows us to accept x requests per second 20 PerSecond Rate = iota 21 // PerMinute allows us to accept x requests per minute 22 PerMinute 23 ) 24 25 // HHMMSS formats a timestamp as HH:MM:SS 26 // Reference: https://yourbasic.org/golang/format-parse-string-time-date-example/ 27 // const HHMMSS = "15:04:05" 28 29 // HHMM formats a timestamp as HH:MM 30 // Reference: https://yourbasic.org/golang/format-parse-string-time-date-example/ 31 // const HHMM = "15:04" 32 33 type RedisIncr struct { 34 Client *redis.Client 35 } 36 37 func NewRedisIncr(config *redis.Options) (*RedisIncr, error) { 38 client := redis.NewClient(config) 39 40 return &RedisIncr{ 41 Client: client, 42 }, nil 43 } 44 45 func (b *RedisIncr) Reset(ctx context.Context, r Request) error { 46 var span openTracing.Span 47 if ctx, span = ntrace.NewMethodSpanFromContext(ctx); span != nil { 48 defer span.Finish() 49 } 50 51 status := b.Client.Ping(ctx) 52 if err := status.Err(); err != nil { 53 return nerror.WrapOnly(err) 54 } 55 56 _, err := b.Client.TxPipelined(ctx, func(pipe redis.Pipeliner) error { 57 return pipe.Set(ctx, r.Owner(), 0, 0).Err() 58 }) 59 if err != nil { 60 return nerror.WrapOnly(err) 61 } 62 return nil 63 } 64 65 func (b *RedisIncr) Count(ctx context.Context, r Request) (int64, error) { 66 var span openTracing.Span 67 if ctx, span = ntrace.NewMethodSpanFromContext(ctx); span != nil { 68 defer span.Finish() 69 } 70 71 status := b.Client.Ping(ctx) 72 if err := status.Err(); err != nil { 73 return -1, nerror.WrapOnly(err) 74 } 75 76 var res = b.Client.Get(ctx, r.Owner()) 77 if err := res.Err(); err != nil { 78 return -1, nerror.WrapOnly(err) 79 } 80 81 var count, readErr = strconv.Atoi(res.Val()) 82 if readErr != nil { 83 return -1, nerror.WrapOnly(readErr) 84 } 85 86 return int64(count), nil 87 } 88 89 func (b *RedisIncr) Inc(ctx context.Context, r Request, dur time.Duration) error { 90 var span openTracing.Span 91 if ctx, span = ntrace.NewMethodSpanFromContext(ctx); span != nil { 92 defer span.Finish() 93 } 94 95 status := b.Client.Ping(ctx) 96 if err := status.Err(); err != nil { 97 return nerror.WrapOnly(err) 98 } 99 100 _, err := b.Client.TxPipelined(ctx, func(pipe redis.Pipeliner) error { 101 pipe.Incr(ctx, r.Owner()) 102 pipe.Expire(ctx, r.Owner(), dur) 103 return nil 104 }) 105 if err != nil { 106 return nerror.WrapOnly(err) 107 } 108 return nil 109 } 110 111 type Request interface { 112 Owner() string 113 Data() interface{} 114 } 115 116 type IncrementStore interface { 117 Reset(ctx context.Context, r Request) error 118 Count(ctx context.Context, r Request) (int64, error) 119 Inc(ctx context.Context, r Request, dur time.Duration) (int64, error) 120 } 121 122 // NewRateLimiter returns a new Limiter. 123 func NewFactory(db IncrementStore, rate Rate) *LimiterFactory { 124 return &LimiterFactory{Store: db, Rate: rate} 125 } 126 127 type LimiterFactory struct { 128 Store IncrementStore 129 Rate Rate 130 } 131 132 // NewLimiter creates a new Limiter. 133 func (f LimiterFactory) New(max int64) *RateLimiter { 134 return &RateLimiter{ 135 store: f.Store, 136 rate: f.Rate, 137 max: max, 138 } 139 } 140 141 // NewLimiter creates a new Limiter. 142 func (f LimiterFactory) NewLimiter(rate Rate, max int64) *RateLimiter { 143 return &RateLimiter{ 144 store: f.Store, 145 rate: rate, 146 max: max, 147 } 148 } 149 150 type RateLimiter struct { 151 store IncrementStore 152 rate Rate 153 max int64 154 } 155 156 // RateLimit applies basic rate limiting to an HTTP request as described 157 // in Redis' onboarding documentation. 158 // Reference: https://redislabs.com/redis-best-practices/basic-rate-limiting/ 159 func (l *RateLimiter) RateLimit(ctx context.Context, r Request) error { 160 var span openTracing.Span 161 if ctx, span = ntrace.NewMethodSpanFromContext(ctx); span != nil { 162 defer span.Finish() 163 } 164 165 var expiry time.Duration 166 167 if l.rate == PerSecond { 168 expiry = time.Second 169 } else { 170 expiry = time.Minute 171 } 172 173 var count, err = l.store.Count(ctx, r) 174 if err != nil { 175 return nerror.WrapOnly(err) 176 } 177 178 if count >= l.max { 179 return nerror.New("requests are throttled, try again later") 180 } 181 182 _, err = l.store.Inc(ctx, r, expiry) 183 if err != nil { 184 return nerror.WrapOnly(err) 185 } 186 return nil 187 }