github.com/avenga/couper@v1.12.2/handler/ratelimit/limiter.go (about)

     1  package ratelimit
     2  
     3  import (
     4  	"net/http"
     5  	"runtime/debug"
     6  	"sync"
     7  	"time"
     8  
     9  	"github.com/avenga/couper/errors"
    10  )
    11  
// Limiter is an http.RoundTripper that enforces a set of rate limits in front
// of a wrapped transport.
//
// Fields:
//   - check: unbuffered queue through which RoundTrip hands requests to the
//     slowTripper worker goroutine, which admits them one at a time.
//   - limits: the configured limits. limits[0] also provides the quit channel
//     and the logger used by the worker.
//   - mu: locked by slowTripper while it checks and updates limit state.
//   - transport: performs the actual round trip once a request is admitted.
type Limiter struct {
	check     chan *slowTrip
	limits    RateLimits
	mu        sync.RWMutex
	transport http.RoundTripper
}
    18  
// slowTrip carries one request through the Limiter.
//
// RoundTrip fills req, out and quitCh and queues the trip. The worker (or the
// goroutine it spawns for the actual round trip) fills res and/or err and
// sends the same trip back on out, which RoundTrip is blocked receiving from.
//
// NOTE(review): quitCh is set from limits[0] but is not read anywhere in
// limiter.go. Confirm whether it is used elsewhere or can be dropped.
type slowTrip struct {
	err    error
	out    chan *slowTrip
	quitCh <-chan struct{}
	req    *http.Request
	res    *http.Response
}
    26  
    27  func NewLimiter(transport http.RoundTripper, limits RateLimits) *Limiter {
    28  	if len(limits) == 0 {
    29  		return nil
    30  	}
    31  
    32  	limiter := &Limiter{
    33  		check:     make(chan *slowTrip),
    34  		limits:    limits,
    35  		transport: transport,
    36  	}
    37  
    38  	for _, rl := range limits {
    39  		// Init the start of a period.
    40  		rl.periodStart = time.Now()
    41  	}
    42  
    43  	go limiter.slowTripper()
    44  
    45  	return limiter
    46  }
    47  
    48  func (l *Limiter) RoundTrip(req *http.Request) (*http.Response, error) {
    49  	outCh := make(chan *slowTrip)
    50  
    51  	trip := &slowTrip{
    52  		out:    outCh,
    53  		quitCh: l.limits[0].quitCh,
    54  		req:    req,
    55  	}
    56  
    57  	select {
    58  	case l.check <- trip:
    59  	case <-req.Context().Done():
    60  		return nil, req.Context().Err()
    61  	}
    62  
    63  	trip = <-outCh
    64  
    65  	return trip.res, trip.err
    66  }
    67  
    68  func (l *Limiter) slowTripper() {
    69  	defer func() {
    70  		if rc := recover(); rc != nil {
    71  			l.limits[0].logger.WithField("panic", string(debug.Stack())).Panic(rc)
    72  		}
    73  	}()
    74  
    75  	for {
    76  		select {
    77  		case <-l.limits[0].quitCh:
    78  			return
    79  		case trip := <-l.check:
    80  			select {
    81  			case <-trip.req.Context().Done():
    82  				// The request was canceled while in the queue.
    83  				trip.err = trip.req.Context().Err()
    84  				trip.out <- trip
    85  
    86  				// Do not sleep for X canceled requests.
    87  				continue
    88  			default:
    89  			}
    90  
    91  			l.mu.Lock()
    92  
    93  			if mode, timeToWait := l.checkCapacity(); mode == modeBlock && timeToWait > 0 {
    94  				// We do not wait, we want block directly.
    95  				trip.err = errors.BetaBackendRateLimitExceeded
    96  				trip.out <- trip
    97  
    98  				l.mu.Unlock()
    99  			} else {
   100  				select {
   101  				// Noop if 'timeToWait' is 0.
   102  				case <-time.After(timeToWait):
   103  				case <-trip.req.Context().Done():
   104  					// The request was canceled while in the queue.
   105  					trip.err = trip.req.Context().Err()
   106  					trip.out <- trip
   107  
   108  					// Do not sleep for X canceled requests.
   109  					continue
   110  				}
   111  
   112  				l.countRequest()
   113  
   114  				l.mu.Unlock()
   115  
   116  				// Do not wait for the response...
   117  				go func() {
   118  					trip.res, trip.err = l.transport.RoundTrip(trip.req)
   119  
   120  					if trip.res != nil && trip.res.StatusCode == http.StatusTooManyRequests {
   121  						trip.err = errors.BetaBackendRateLimitExceeded.With(trip.err)
   122  					}
   123  
   124  					trip.out <- trip
   125  				}()
   126  			}
   127  		}
   128  	}
   129  }
   130  
   131  func (l *Limiter) checkCapacity() (mode int, t time.Duration) {
   132  	now := time.Now()
   133  
   134  	for _, rl := range l.limits {
   135  		switch rl.window {
   136  		case windowFixed:
   137  			// Update current period.
   138  			multiplicator := ((now.UnixNano() - rl.periodStart.UnixNano()) / int64(time.Nanosecond)) / rl.period.Nanoseconds()
   139  			if multiplicator > 0 {
   140  				rl.periodStart = rl.periodStart.Add(time.Duration(rl.period.Nanoseconds() * multiplicator))
   141  				rl.count = 0
   142  			}
   143  
   144  			if rl.count >= rl.perPeriod {
   145  				// Calculate the 'timeToWait'.
   146  				t = time.Duration((rl.periodStart.Add(rl.period).UnixNano() - now.UnixNano()) / int64(time.Nanosecond))
   147  
   148  				mode = rl.mode
   149  			}
   150  		case windowSliding:
   151  			latest := rl.ringBuffer.get()
   152  
   153  			if !latest.IsZero() && latest.Add(rl.period).After(now) {
   154  				// Calculate the 'timeToWait'.
   155  				t = time.Duration((latest.Add(rl.period).UnixNano() - now.UnixNano()) / int64(time.Nanosecond))
   156  
   157  				mode = rl.mode
   158  			}
   159  		}
   160  	}
   161  
   162  	return
   163  }
   164  
   165  // countRequest MUST only be called after checkCapacity()
   166  func (l *Limiter) countRequest() {
   167  	for _, rl := range l.limits {
   168  		rl.countRequest()
   169  	}
   170  }