github.com/onflow/flow-go@v0.35.7-crescendo-preview.23-atree-inlining/engine/common/rpc/rate_limit_interceptor.go (about) 1 package rpc 2 3 import ( 4 "context" 5 "path/filepath" 6 7 "github.com/rs/zerolog" 8 "golang.org/x/time/rate" 9 "google.golang.org/grpc" 10 "google.golang.org/grpc/codes" 11 "google.golang.org/grpc/status" 12 ) 13 14 const defaultRateLimit = 1000 // aggregate default rate limit for all unspecified API calls 15 const defaultBurst = 100 // default burst limit (calls made at the same time) for an API 16 17 // rateLimiterInterceptor rate limits the 18 type rateLimiterInterceptor struct { 19 log zerolog.Logger 20 21 // a shared default rate limiter for APIs whose rate limit is not explicitly defined 22 defaultLimiter *rate.Limiter 23 24 // a map of api and its limiter 25 methodLimiterMap map[string]*rate.Limiter 26 } 27 28 // NewRateLimiterInterceptor creates a new rate limiter interceptor with the defined per second rate limits and the 29 // optional burst limit for each API. 30 func NewRateLimiterInterceptor(log zerolog.Logger, apiRateLimits map[string]int, apiBurstLimits map[string]int) *rateLimiterInterceptor { 31 32 defaultLimiter := rate.NewLimiter(rate.Limit(defaultRateLimit), defaultBurst) 33 methodLimiterMap := make(map[string]*rate.Limiter, len(apiRateLimits)) 34 35 // read rate limit values for each API and create a limiter for each 36 for api, limit := range apiRateLimits { 37 // if a burst limit is defined for this api, use that else use the default 38 burst := defaultBurst 39 if b, ok := apiBurstLimits[api]; ok { 40 burst = b 41 } 42 methodLimiterMap[api] = rate.NewLimiter(rate.Limit(limit), burst) 43 } 44 45 if len(methodLimiterMap) == 0 { 46 log.Info().Int("default_rate_limit", defaultRateLimit).Msg("no rate limits specified, using the default limit") 47 } 48 49 return &rateLimiterInterceptor{ 50 defaultLimiter: defaultLimiter, 51 methodLimiterMap: methodLimiterMap, 52 log: log, 53 } 54 } 55 56 // UnaryServerInterceptor rate limits the given request based on the limits defined when creating the rateLimiterInterceptor 57 func (interceptor *rateLimiterInterceptor) UnaryServerInterceptor(ctx context.Context, 58 req interface{}, 59 info *grpc.UnaryServerInfo, 60 handler grpc.UnaryHandler) (resp interface{}, err error) { 61 62 // remove the package name (e.g. "/flow.access.AccessAPI/Ping" to "Ping") 63 methodName := filepath.Base(info.FullMethod) 64 65 // look up the limiter 66 limiter := interceptor.methodLimiterMap[methodName] 67 68 // if not found, use the default limiter 69 if limiter == nil { 70 71 interceptor.log.Trace().Str("method", methodName).Msg("rate limit not defined, using default limit") 72 73 limiter = interceptor.defaultLimiter 74 } 75 76 // check if request within limit 77 if !limiter.Allow() { 78 79 // log the limit violation 80 interceptor.log.Trace(). 81 Str("method", methodName). 82 Interface("request", req). 83 Float64("limit", float64(limiter.Limit())). 84 Msg("rate limit exceeded") 85 86 // reject the request 87 return nil, status.Errorf(codes.ResourceExhausted, "%s rate limit reached, please retry later.", 88 info.FullMethod) 89 } 90 91 // call the handler 92 h, err := handler(ctx, req) 93 94 return h, err 95 }