github.com/avenga/couper@v1.12.2/handler/ratelimit/limiter.go (about) 1 package ratelimit 2 3 import ( 4 "net/http" 5 "runtime/debug" 6 "sync" 7 "time" 8 9 "github.com/avenga/couper/errors" 10 ) 11 12 type Limiter struct { 13 check chan *slowTrip 14 limits RateLimits 15 mu sync.RWMutex 16 transport http.RoundTripper 17 } 18 19 type slowTrip struct { 20 err error 21 out chan *slowTrip 22 quitCh <-chan struct{} 23 req *http.Request 24 res *http.Response 25 } 26 27 func NewLimiter(transport http.RoundTripper, limits RateLimits) *Limiter { 28 if len(limits) == 0 { 29 return nil 30 } 31 32 limiter := &Limiter{ 33 check: make(chan *slowTrip), 34 limits: limits, 35 transport: transport, 36 } 37 38 for _, rl := range limits { 39 // Init the start of a period. 40 rl.periodStart = time.Now() 41 } 42 43 go limiter.slowTripper() 44 45 return limiter 46 } 47 48 func (l *Limiter) RoundTrip(req *http.Request) (*http.Response, error) { 49 outCh := make(chan *slowTrip) 50 51 trip := &slowTrip{ 52 out: outCh, 53 quitCh: l.limits[0].quitCh, 54 req: req, 55 } 56 57 select { 58 case l.check <- trip: 59 case <-req.Context().Done(): 60 return nil, req.Context().Err() 61 } 62 63 trip = <-outCh 64 65 return trip.res, trip.err 66 } 67 68 func (l *Limiter) slowTripper() { 69 defer func() { 70 if rc := recover(); rc != nil { 71 l.limits[0].logger.WithField("panic", string(debug.Stack())).Panic(rc) 72 } 73 }() 74 75 for { 76 select { 77 case <-l.limits[0].quitCh: 78 return 79 case trip := <-l.check: 80 select { 81 case <-trip.req.Context().Done(): 82 // The request was canceled while in the queue. 83 trip.err = trip.req.Context().Err() 84 trip.out <- trip 85 86 // Do not sleep for X canceled requests. 87 continue 88 default: 89 } 90 91 l.mu.Lock() 92 93 if mode, timeToWait := l.checkCapacity(); mode == modeBlock && timeToWait > 0 { 94 // We do not wait, we want block directly. 95 trip.err = errors.BetaBackendRateLimitExceeded 96 trip.out <- trip 97 98 l.mu.Unlock() 99 } else { 100 select { 101 // Noop if 'timeToWait' is 0. 102 case <-time.After(timeToWait): 103 case <-trip.req.Context().Done(): 104 // The request was canceled while in the queue. 105 trip.err = trip.req.Context().Err() 106 trip.out <- trip 107 108 // Do not sleep for X canceled requests. 109 continue 110 } 111 112 l.countRequest() 113 114 l.mu.Unlock() 115 116 // Do not wait for the response... 117 go func() { 118 trip.res, trip.err = l.transport.RoundTrip(trip.req) 119 120 if trip.res != nil && trip.res.StatusCode == http.StatusTooManyRequests { 121 trip.err = errors.BetaBackendRateLimitExceeded.With(trip.err) 122 } 123 124 trip.out <- trip 125 }() 126 } 127 } 128 } 129 } 130 131 func (l *Limiter) checkCapacity() (mode int, t time.Duration) { 132 now := time.Now() 133 134 for _, rl := range l.limits { 135 switch rl.window { 136 case windowFixed: 137 // Update current period. 138 multiplicator := ((now.UnixNano() - rl.periodStart.UnixNano()) / int64(time.Nanosecond)) / rl.period.Nanoseconds() 139 if multiplicator > 0 { 140 rl.periodStart = rl.periodStart.Add(time.Duration(rl.period.Nanoseconds() * multiplicator)) 141 rl.count = 0 142 } 143 144 if rl.count >= rl.perPeriod { 145 // Calculate the 'timeToWait'. 146 t = time.Duration((rl.periodStart.Add(rl.period).UnixNano() - now.UnixNano()) / int64(time.Nanosecond)) 147 148 mode = rl.mode 149 } 150 case windowSliding: 151 latest := rl.ringBuffer.get() 152 153 if !latest.IsZero() && latest.Add(rl.period).After(now) { 154 // Calculate the 'timeToWait'. 155 t = time.Duration((latest.Add(rl.period).UnixNano() - now.UnixNano()) / int64(time.Nanosecond)) 156 157 mode = rl.mode 158 } 159 } 160 } 161 162 return 163 } 164 165 // countRequest MUST only be called after checkCapacity() 166 func (l *Limiter) countRequest() { 167 for _, rl := range l.limits { 168 rl.countRequest() 169 } 170 }