code.vegaprotocol.io/vega@v0.79.0/datanode/ratelimit/ratelimit.go (about)

     1  // Copyright (C) 2023 Gobalsky Labs Limited
     2  //
     3  // This program is free software: you can redistribute it and/or modify
     4  // it under the terms of the GNU Affero General Public License as
     5  // published by the Free Software Foundation, either version 3 of the
     6  // License, or (at your option) any later version.
     7  //
     8  // This program is distributed in the hope that it will be useful,
     9  // but WITHOUT ANY WARRANTY; without even the implied warranty of
    10  // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    11  // GNU Affero General Public License for more details.
    12  //
    13  // You should have received a copy of the GNU Affero General Public License
    14  // along with this program.  If not, see <http://www.gnu.org/licenses/>.
    15  
    16  package ratelimit
    17  
    18  import (
    19  	"context"
    20  	"fmt"
    21  	"math"
    22  	"net"
    23  	"net/http"
    24  	"sync/atomic"
    25  	"time"
    26  
    27  	"code.vegaprotocol.io/vega/datanode/contextutil"
    28  	"code.vegaprotocol.io/vega/logging"
    29  
    30  	"github.com/didip/tollbooth/v7"
    31  	"github.com/didip/tollbooth/v7/libstring"
    32  	"github.com/didip/tollbooth/v7/limiter"
    33  	"github.com/google/uuid"
    34  	"google.golang.org/grpc"
    35  	"google.golang.org/grpc/codes"
    36  	"google.golang.org/grpc/metadata"
    37  	"google.golang.org/grpc/status"
    38  )
    39  
    40  var (
    41  	secret   string
    42  	banMsg   = "temporarily banned for continuing to request while rate limited"
    43  	limitMsg = "api request rate limit exceeded"
    44  )
    45  
    46  // init sets our random per-process secret generated at startup.
    47  //
    48  // If the "X-Rate-Limit-Secret": <secret> is present in GRPC metadata, rate limiting will not be applied.
    49  func init() {
    50  	secret = uuid.New().String()
    51  }
    52  
    53  // WithSecret is a GRPC dial option that adds the "X-Rate-Limit-Secret": <secret> header to all calls.
    54  func WithSecret() grpc.DialOption {
    55  	interceptor := func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
    56  		ctx = metadata.AppendToOutgoingContext(ctx, "RateLimit-Secret", secret)
    57  		return invoker(ctx, method, req, reply, cc, opts...)
    58  	}
    59  	return grpc.WithUnaryInterceptor(interceptor)
    60  }
    61  
// RateLimit applies per-IP request rate limiting to HTTP traffic (HTTPMiddleware) and to
// unary gRPC calls (GRPCInterceptor). IPs that keep requesting after hitting the limit
// are reported to a naughty step, which can temporarily ban them.
type RateLimit struct {
	// lmt is the tollbooth limiter that decides whether an individual request is allowed.
	lmt         *limiter.Limiter
	// cfg holds the active configuration. It is an atomic pointer so ReloadConfig can
	// swap it while request handlers read it.
	cfg         atomic.Pointer[Config]
	log         *logging.Logger
	// naughtyStep records IPs that exceed the limit and answers whether an IP is banned.
	naughtyStep *naughtyStep
}
    68  
    69  func NewFromConfig(cfg *Config, log *logging.Logger) *RateLimit {
    70  	limitOpts := limiter.ExpirableOptions{DefaultExpirationTTL: cfg.TTL.Duration}
    71  	lmt := tollbooth.NewLimiter(cfg.Rate, &limitOpts)
    72  	lmt.SetBurst(cfg.Burst)
    73  
    74  	// The naughty step limiter could have a different rate/burst but it seemed likely to add
    75  	// more confusion than it's worth to the configuration & these should be sensible.
    76  	ns := newNaughtyStep(log, cfg.Rate, cfg.Burst, cfg.BanFor.Duration, cfg.TTL.Duration)
    77  
    78  	r := &RateLimit{
    79  		lmt:         lmt,
    80  		naughtyStep: ns,
    81  		log:         log,
    82  	}
    83  	r.cfg.Store(cfg)
    84  	return r
    85  }
    86  
    87  func (r *RateLimit) ReloadConfig(cfg *Config) {
    88  	r.log.Info("updating rate limit configuration",
    89  		logging.String("old", fmt.Sprintf("%v", r.cfg.Load())),
    90  		logging.String("new", fmt.Sprintf("%v", cfg)))
    91  
    92  	r.cfg.Store(cfg)
    93  	r.lmt.SetBurst(cfg.Burst).
    94  		SetMax(cfg.Rate)
    95  	r.naughtyStep.lmt.SetBurst(cfg.Burst).
    96  		SetMax(cfg.Rate)
    97  	r.naughtyStep.banFor = cfg.BanFor.Duration
    98  }
    99  
   100  func (r *RateLimit) HTTPMiddleware(next http.Handler) http.Handler {
   101  	middle := func(w http.ResponseWriter, req *http.Request) {
   102  		if !r.cfg.Load().Enabled {
   103  			next.ServeHTTP(w, req)
   104  			return
   105  		}
   106  
   107  		ip := r.ipForRequest(req)
   108  
   109  		if r.naughtyStep.isBanned(ip) {
   110  			r.expressDisappointment(w, banMsg, ip, http.StatusForbidden, true)
   111  			return
   112  		}
   113  
   114  		if httpError := tollbooth.LimitByRequest(r.lmt, w, req); httpError != nil {
   115  			r.naughtyStep.smackBottom(ip)
   116  			r.expressDisappointment(w, limitMsg, ip, http.StatusTooManyRequests, false)
   117  			return
   118  		}
   119  
   120  		next.ServeHTTP(w, req)
   121  	}
   122  	return http.HandlerFunc(middle)
   123  }
   124  
   125  func (r *RateLimit) expressDisappointment(w http.ResponseWriter, msg, ip string, status int, banned bool) {
   126  	w.Header().Add("Content-Type", "application/json")
   127  
   128  	if banned {
   129  		expiry := r.naughtyStep.bans[ip]
   130  		remaining := time.Until(expiry).Seconds()
   131  
   132  		w.Header().Add("RateLimit-Retry-After", fmt.Sprintf("%0.f", remaining))
   133  	}
   134  	w.WriteHeader(status)
   135  	_, _ = w.Write([]byte(msg))
   136  }
   137  
   138  func (r *RateLimit) ipForRequest(req *http.Request) string {
   139  	ip := libstring.RemoteIP(r.lmt.GetIPLookups(), r.lmt.GetForwardedForIndexFromBehind(), req)
   140  	return libstring.CanonicalizeIP(ip)
   141  }
   142  
   143  func (r *RateLimit) GRPCInterceptor(
   144  	ctx context.Context,
   145  	req interface{},
   146  	_ *grpc.UnaryServerInfo,
   147  	handler grpc.UnaryHandler,
   148  ) (resp interface{}, err error) {
   149  	if !r.cfg.Load().Enabled {
   150  		return handler(ctx, req)
   151  	}
   152  
   153  	// Check if the client gave the secret in the metadata, if so skip rate limiting
   154  	md, ok := metadata.FromIncomingContext(ctx)
   155  	if ok {
   156  		mdSecrets := md.Get("RateLimit-Secret")
   157  		for _, mdSecret := range mdSecrets {
   158  			if mdSecret == secret {
   159  				return handler(ctx, req)
   160  			}
   161  		}
   162  	}
   163  
   164  	// Fish out IP address from context
   165  	addr, ok := contextutil.RemoteIPAddrFromContext(ctx)
   166  	if !ok {
   167  		// If we don't have an IP we can't rate limit
   168  		return handler(ctx, req)
   169  	}
   170  
   171  	ip, _, err := net.SplitHostPort(addr)
   172  	if err != nil {
   173  		ip = addr
   174  	}
   175  	ip = libstring.CanonicalizeIP(ip)
   176  
   177  	// Check the naughty step
   178  	if r.naughtyStep.isBanned(ip) {
   179  		expiry := r.naughtyStep.bans[ip]
   180  		remaining := time.Until(expiry).Seconds()
   181  
   182  		if err := grpc.SetHeader(ctx, metadata.Pairs("RateLimit-Retry-After", fmt.Sprintf("%0.f", remaining))); err != nil {
   183  			r.log.Error("failed to set header", logging.Error(err))
   184  		}
   185  
   186  		// codes.PermissionDenied is translated to HTTP 403 Forbidden
   187  		return nil, status.Error(codes.PermissionDenied, banMsg)
   188  	}
   189  
   190  	if r.lmt.LimitReached(ip) {
   191  		r.naughtyStep.smackBottom(ip)
   192  		setRateLimitResponseHeaders(ctx, r.log, r.lmt, 0, ip)
   193  		// code.ResourceExhausted is translated to HTTP 429 Too Many Requests
   194  		return nil, status.Error(codes.ResourceExhausted, limitMsg)
   195  	}
   196  
   197  	tokensLeft := r.lmt.Tokens(ip)
   198  	setRateLimitResponseHeaders(ctx, r.log, r.lmt, tokensLeft, ip)
   199  	return handler(ctx, req)
   200  }
   201  
   202  // setRateLimitResponseHeaders configures RateLimit-Limit, RateLimit-Remaining and RateLimit-Reset
   203  // as seen at https://datatracker.ietf.org/doc/html/draft-ietf-httpapi-ratelimit-headers
   204  func setRateLimitResponseHeaders(ctx context.Context, log *logging.Logger, lmt *limiter.Limiter, tokensLeft int, ip string) {
   205  	for _, h := range []metadata.MD{
   206  		metadata.Pairs("RateLimit-Limit", fmt.Sprintf("%d", int(math.Round(lmt.GetMax())))),
   207  		metadata.Pairs("RateLimit-Reset", "1"),
   208  		metadata.Pairs("RateLimit-Remaining", fmt.Sprintf("%d", tokensLeft)),
   209  		metadata.Pairs("RateLimit-Request-Remote-Addr", ip),
   210  	} {
   211  		if errH := grpc.SetHeader(ctx, h); errH != nil {
   212  			log.Error("failed to set header", logging.Error(errH))
   213  		}
   214  	}
   215  }