github.com/onflow/flow-go@v0.35.7-crescendo-preview.23-atree-inlining/engine/common/rpc/rate_limit_interceptor.go (about)

     1  package rpc
     2  
     3  import (
     4  	"context"
     5  	"path/filepath"
     6  
     7  	"github.com/rs/zerolog"
     8  	"golang.org/x/time/rate"
     9  	"google.golang.org/grpc"
    10  	"google.golang.org/grpc/codes"
    11  	"google.golang.org/grpc/status"
    12  )
    13  
    14  const defaultRateLimit = 1000 // aggregate default rate limit for all unspecified API calls
    15  const defaultBurst = 100      // default burst limit (calls made at the same time) for an API
    16  
    17  // rateLimiterInterceptor rate limits the
    18  type rateLimiterInterceptor struct {
    19  	log zerolog.Logger
    20  
    21  	// a shared default rate limiter for APIs whose rate limit is not explicitly defined
    22  	defaultLimiter *rate.Limiter
    23  
    24  	// a map of api and its limiter
    25  	methodLimiterMap map[string]*rate.Limiter
    26  }
    27  
    28  // NewRateLimiterInterceptor creates a new rate limiter interceptor with the defined per second rate limits and the
    29  // optional burst limit for each API.
    30  func NewRateLimiterInterceptor(log zerolog.Logger, apiRateLimits map[string]int, apiBurstLimits map[string]int) *rateLimiterInterceptor {
    31  
    32  	defaultLimiter := rate.NewLimiter(rate.Limit(defaultRateLimit), defaultBurst)
    33  	methodLimiterMap := make(map[string]*rate.Limiter, len(apiRateLimits))
    34  
    35  	// read rate limit values for each API and create a limiter for each
    36  	for api, limit := range apiRateLimits {
    37  		// if a burst limit is defined for this api, use that else use the default
    38  		burst := defaultBurst
    39  		if b, ok := apiBurstLimits[api]; ok {
    40  			burst = b
    41  		}
    42  		methodLimiterMap[api] = rate.NewLimiter(rate.Limit(limit), burst)
    43  	}
    44  
    45  	if len(methodLimiterMap) == 0 {
    46  		log.Info().Int("default_rate_limit", defaultRateLimit).Msg("no rate limits specified, using the default limit")
    47  	}
    48  
    49  	return &rateLimiterInterceptor{
    50  		defaultLimiter:   defaultLimiter,
    51  		methodLimiterMap: methodLimiterMap,
    52  		log:              log,
    53  	}
    54  }
    55  
    56  // UnaryServerInterceptor rate limits the given request based on the limits defined when creating the rateLimiterInterceptor
    57  func (interceptor *rateLimiterInterceptor) UnaryServerInterceptor(ctx context.Context,
    58  	req interface{},
    59  	info *grpc.UnaryServerInfo,
    60  	handler grpc.UnaryHandler) (resp interface{}, err error) {
    61  
    62  	// remove the package name (e.g. "/flow.access.AccessAPI/Ping" to "Ping")
    63  	methodName := filepath.Base(info.FullMethod)
    64  
    65  	// look up the limiter
    66  	limiter := interceptor.methodLimiterMap[methodName]
    67  
    68  	// if not found, use the default limiter
    69  	if limiter == nil {
    70  
    71  		interceptor.log.Trace().Str("method", methodName).Msg("rate limit not defined, using default limit")
    72  
    73  		limiter = interceptor.defaultLimiter
    74  	}
    75  
    76  	// check if request within limit
    77  	if !limiter.Allow() {
    78  
    79  		// log the limit violation
    80  		interceptor.log.Trace().
    81  			Str("method", methodName).
    82  			Interface("request", req).
    83  			Float64("limit", float64(limiter.Limit())).
    84  			Msg("rate limit exceeded")
    85  
    86  		// reject the request
    87  		return nil, status.Errorf(codes.ResourceExhausted, "%s rate limit reached, please retry later.",
    88  			info.FullMethod)
    89  	}
    90  
    91  	// call the handler
    92  	h, err := handler(ctx, req)
    93  
    94  	return h, err
    95  }