go.temporal.io/server@v1.23.0/common/quotas/clocked_rate_limiter.go (about)

     1  // The MIT License
     2  //
     3  // Copyright (c) 2020 Temporal Technologies Inc.  All rights reserved.
     4  //
     5  // Copyright (c) 2020 Uber Technologies, Inc.
     6  //
     7  // Permission is hereby granted, free of charge, to any person obtaining a copy
     8  // of this software and associated documentation files (the "Software"), to deal
     9  // in the Software without restriction, including without limitation the rights
    10  // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
    11  // copies of the Software, and to permit persons to whom the Software is
    12  // furnished to do so, subject to the following conditions:
    13  //
    14  // The above copyright notice and this permission notice shall be included in
    15  // all copies or substantial portions of the Software.
    16  //
    17  // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
    18  // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
    19  // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
    20  // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
    21  // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
    22  // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
    23  // THE SOFTWARE.
    24  
    25  package quotas
    26  
    27  import (
    28  	"context"
    29  	"errors"
    30  	"fmt"
    31  	"time"
    32  
    33  	"go.temporal.io/server/common/clock"
    34  	"golang.org/x/time/rate"
    35  )
    36  
// ClockedRateLimiter wraps a rate.Limiter with a clock.TimeSource. Methods that do not take an explicit time
// (Allow, Reserve, Wait/WaitN) read the current time from timeSource instead of the wall clock, and WaitN also
// schedules its wake-up through timeSource, so the limiter respects the time determined by the timeSource.
type ClockedRateLimiter struct {
	// rateLimiter is the underlying limiter that every method delegates to.
	rateLimiter *rate.Limiter
	// timeSource supplies "now" for the time-less methods and the timer used by WaitN.
	timeSource  clock.TimeSource
}
    43  
    44  var (
    45  	ErrRateLimiterWaitInterrupted                       = errors.New("rate limiter wait interrupted")
    46  	ErrRateLimiterReservationCannotBeMade               = errors.New("rate limiter reservation cannot be made due to insufficient quota")
    47  	ErrRateLimiterReservationWouldExceedContextDeadline = errors.New("rate limiter reservation would exceed context deadline")
    48  )
    49  
    50  func NewClockedRateLimiter(rateLimiter *rate.Limiter, timeSource clock.TimeSource) ClockedRateLimiter {
    51  	return ClockedRateLimiter{
    52  		rateLimiter: rateLimiter,
    53  		timeSource:  timeSource,
    54  	}
    55  }
    56  
    57  func (l ClockedRateLimiter) Allow() bool {
    58  	return l.AllowN(l.timeSource.Now(), 1)
    59  }
    60  
    61  func (l ClockedRateLimiter) AllowN(now time.Time, token int) bool {
    62  	return l.rateLimiter.AllowN(now, token)
    63  }
    64  
// ClockedReservation wraps a rate.Reservation with a clock.TimeSource. Delay and Cancel read the current time from
// timeSource rather than the wall clock, so the reservation respects the time determined by the timeSource.
type ClockedReservation struct {
	// reservation is the wrapped rate.Reservation (built in this file from rate.Limiter.ReserveN).
	reservation *rate.Reservation
	// timeSource supplies "now" for Delay and Cancel.
	timeSource  clock.TimeSource
}
    71  
    72  func (r ClockedReservation) OK() bool {
    73  	return r.reservation.OK()
    74  }
    75  
    76  func (r ClockedReservation) Delay() time.Duration {
    77  	return r.DelayFrom(r.timeSource.Now())
    78  }
    79  
    80  func (r ClockedReservation) DelayFrom(t time.Time) time.Duration {
    81  	return r.reservation.DelayFrom(t)
    82  }
    83  
    84  func (r ClockedReservation) Cancel() {
    85  	r.CancelAt(r.timeSource.Now())
    86  }
    87  
    88  func (r ClockedReservation) CancelAt(t time.Time) {
    89  	r.reservation.CancelAt(t)
    90  }
    91  
    92  func (l ClockedRateLimiter) Reserve() ClockedReservation {
    93  	return l.ReserveN(l.timeSource.Now(), 1)
    94  }
    95  
    96  func (l ClockedRateLimiter) ReserveN(now time.Time, token int) ClockedReservation {
    97  	reservation := l.rateLimiter.ReserveN(now, token)
    98  	return ClockedReservation{reservation, l.timeSource}
    99  }
   100  
   101  func (l ClockedRateLimiter) Wait(ctx context.Context) error {
   102  	return l.WaitN(ctx, 1)
   103  }
   104  
   105  // WaitN is the only method that is different from rate.Limiter. We need to fully reimplement this method because
   106  // the original method uses time.Now(), and does not allow us to pass in a time.Time. Fortunately, it can be built on
   107  // top of ReserveN. However, there are some optimizations that we can make.
   108  func (l ClockedRateLimiter) WaitN(ctx context.Context, token int) error {
   109  	reservation := ClockedReservation{l.rateLimiter.ReserveN(l.timeSource.Now(), token), l.timeSource}
   110  	if !reservation.OK() {
   111  		return fmt.Errorf("%w: WaitN(n=%d)", ErrRateLimiterReservationCannotBeMade, token)
   112  	}
   113  
   114  	waitDuration := reservation.Delay()
   115  
   116  	// Optimization: if the waitDuration is 0, we don't need to start a timer.
   117  	if waitDuration <= 0 {
   118  		return nil
   119  	}
   120  
   121  	// Optimization: if the waitDuration is longer than the context deadline, we can immediately return an error.
   122  	if deadline, ok := ctx.Deadline(); ok {
   123  		if l.timeSource.Now().Add(waitDuration).After(deadline) {
   124  			reservation.Cancel()
   125  			return fmt.Errorf("%w: WaitN(n=%d)", ErrRateLimiterReservationWouldExceedContextDeadline, token)
   126  		}
   127  	}
   128  
   129  	waitExpired := make(chan struct{})
   130  	timer := l.timeSource.AfterFunc(waitDuration, func() {
   131  		close(waitExpired)
   132  	})
   133  	defer timer.Stop()
   134  	select {
   135  	case <-ctx.Done():
   136  		reservation.Cancel()
   137  		return fmt.Errorf("%w: %v", ErrRateLimiterWaitInterrupted, ctx.Err())
   138  	case <-waitExpired:
   139  		return nil
   140  	}
   141  }
   142  
   143  func (l ClockedRateLimiter) SetLimitAt(t time.Time, newLimit rate.Limit) {
   144  	l.rateLimiter.SetLimitAt(t, newLimit)
   145  }
   146  
   147  func (l ClockedRateLimiter) SetBurstAt(t time.Time, newBurst int) {
   148  	l.rateLimiter.SetBurstAt(t, newBurst)
   149  }