code.vegaprotocol.io/vega@v0.79.0/datanode/gateway/middleware.go (about)

     1  // Copyright (C) 2023 Gobalsky Labs Limited
     2  //
     3  // This program is free software: you can redistribute it and/or modify
     4  // it under the terms of the GNU Affero General Public License as
     5  // published by the Free Software Foundation, either version 3 of the
     6  // License, or (at your option) any later version.
     7  //
     8  // This program is distributed in the hope that it will be useful,
     9  // but WITHOUT ANY WARRANTY; without even the implied warranty of
    10  // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    11  // GNU Affero General Public License for more details.
    12  //
    13  // You should have received a copy of the GNU Affero General Public License
    14  // along with this program.  If not, see <http://www.gnu.org/licenses/>.
    15  
    16  package gateway
    17  
    18  import (
    19  	"context"
    20  	"errors"
    21  	"fmt"
    22  	"net"
    23  	"net/http"
    24  	"strings"
    25  	"sync"
    26  	"time"
    27  
    28  	"code.vegaprotocol.io/vega/datanode/contextutil"
    29  	"code.vegaprotocol.io/vega/datanode/metrics"
    30  	vfmt "code.vegaprotocol.io/vega/libs/fmt"
    31  	vhttp "code.vegaprotocol.io/vega/libs/http"
    32  	"code.vegaprotocol.io/vega/logging"
    33  
    34  	"google.golang.org/grpc"
    35  	"google.golang.org/grpc/codes"
    36  	"google.golang.org/grpc/status"
    37  )
    38  
    39  var ErrMaxSubscriptionReached = func(ip string, max uint32) error {
    40  	return fmt.Errorf("max subscriptions count (%v) reached for ip (%s)", max, ip)
    41  }
    42  
    43  // RemoteAddrMiddleware is a middleware adding to the current request context the
    44  // address of the caller.
    45  func RemoteAddrMiddleware(log *logging.Logger, next http.Handler) http.Handler {
    46  	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
    47  		ip, err := vhttp.RemoteAddr(r)
    48  		if err != nil {
    49  			log.Debug("Failed to get remote address in middleware",
    50  				logging.String("remote-addr", r.RemoteAddr),
    51  				logging.String("x-forwarded-for", vfmt.Escape(r.Header.Get("X-Forwarded-For"))),
    52  			)
    53  		} else {
    54  			r = r.WithContext(contextutil.WithRemoteIPAddr(r.Context(), ip))
    55  		}
    56  		next.ServeHTTP(w, r)
    57  	})
    58  }
    59  
    60  // MetricCollectionMiddleware records the request and the time taken to service it.
    61  func MetricCollectionMiddleware(next http.Handler) http.Handler {
    62  	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
    63  		start := time.Now()
    64  		next.ServeHTTP(w, r)
    65  		end := time.Now()
    66  
    67  		// Update the call count and timings in metrics
    68  		timetaken := end.Sub(start)
    69  
    70  		metrics.APIRequestAndTimeREST(r.Method, r.RequestURI, timetaken.Seconds())
    71  	})
    72  }
    73  
    74  // Chain builds the middleware Chain recursively, functions are first class.
    75  func Chain(f http.Handler, m ...func(http.Handler) http.Handler) http.Handler {
    76  	// if our Chain is done, use the original handler func
    77  	if len(m) == 0 {
    78  		return f
    79  	}
    80  	// otherwise nest the handler funcs
    81  	return m[0](Chain(f, m[1:cap(m)]...))
    82  }
    83  
    84  func WithAddHeadersMiddleware(next http.Handler) http.Handler {
    85  	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
    86  		ctx := r.Context()
    87  		hijacker, ok := w.(http.Hijacker)
    88  		if !ok {
    89  			next.ServeHTTP(w, r)
    90  			return
    91  		}
    92  		iw := &InjectableResponseWriter{
    93  			ResponseWriter: w,
    94  			Hijacker:       hijacker,
    95  		}
    96  		ctx = context.WithValue(ctx, injectableWriterKey{}, iw)
    97  		next.ServeHTTP(iw, r.WithContext(ctx))
    98  	})
    99  }
   100  
   101  type InjectableResponseWriter struct {
   102  	http.ResponseWriter
   103  	http.Hijacker
   104  	headers http.Header
   105  }
   106  
   107  type injectableWriterKey struct{}
   108  
   109  func InjectableWriterFromContext(ctx context.Context) (*InjectableResponseWriter, bool) {
   110  	if ctx == nil {
   111  		return nil, false
   112  	}
   113  	val := ctx.Value(injectableWriterKey{})
   114  	if val == nil {
   115  		return nil, false
   116  	}
   117  	return val.(*InjectableResponseWriter), true
   118  }
   119  
   120  func (i *InjectableResponseWriter) Write(data []byte) (int, error) {
   121  	for k, v := range i.headers {
   122  		if len(v) > 0 {
   123  			i.ResponseWriter.Header().Add(k, v[0])
   124  		}
   125  	}
   126  	return i.ResponseWriter.Write(data)
   127  }
   128  
   129  func (i *InjectableResponseWriter) SetHeaders(headers http.Header) {
   130  	i.headers = headers
   131  }
   132  
   133  type SubscriptionRateLimiter struct {
   134  	log *logging.Logger
   135  	m   map[string]uint32
   136  	mu  sync.Mutex
   137  
   138  	MaxSubscriptions uint32
   139  }
   140  
   141  func NewSubscriptionRateLimiter(
   142  	log *logging.Logger,
   143  	maxSubscriptions uint32,
   144  ) *SubscriptionRateLimiter {
   145  	return &SubscriptionRateLimiter{
   146  		log:              log,
   147  		MaxSubscriptions: maxSubscriptions,
   148  		m:                map[string]uint32{},
   149  	}
   150  }
   151  
   152  func (s *SubscriptionRateLimiter) Inc(ip string) error {
   153  	s.mu.Lock()
   154  	defer s.mu.Unlock()
   155  	cnt := s.m[ip]
   156  	if cnt == s.MaxSubscriptions {
   157  		return ErrMaxSubscriptionReached(ip, s.MaxSubscriptions)
   158  	}
   159  	s.m[ip] = cnt + 1
   160  	return nil
   161  }
   162  
   163  func (s *SubscriptionRateLimiter) Dec(ip string) {
   164  	s.mu.Lock()
   165  	defer s.mu.Unlock()
   166  	cnt := s.m[ip]
   167  	s.m[ip] = cnt - 1
   168  }
   169  
   170  func (s *SubscriptionRateLimiter) WithSubscriptionRateLimiter(next http.Handler) http.Handler {
   171  	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   172  		// is that a subscription?
   173  		if _, ok := w.(http.Hijacker); !ok {
   174  			next.ServeHTTP(w, r)
   175  			return
   176  		}
   177  
   178  		if ip, err := getIP(r); err != nil {
   179  			s.log.Debug("couldn't get client ip", logging.Error(err))
   180  		} else {
   181  			if err := s.Inc(ip); err != nil {
   182  				s.log.Error("client reached max subscription allowed",
   183  					logging.Error(err))
   184  				w.WriteHeader(http.StatusTooManyRequests)
   185  				w.Write([]byte(err.Error()))
   186  				// write error
   187  				return
   188  			}
   189  			defer func() {
   190  				s.Dec(ip)
   191  			}()
   192  		}
   193  
   194  		next.ServeHTTP(w, r)
   195  	})
   196  }
   197  
   198  type ipGetter func(ctx context.Context, method string, log *logging.Logger) (string, error)
   199  
   200  func (s *SubscriptionRateLimiter) WithGrpcInterceptor(ipGetterFunc ipGetter) grpc.StreamServerInterceptor {
   201  	return func(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
   202  		addr, err := ipGetterFunc(ss.Context(), info.FullMethod, s.log)
   203  		if err != nil {
   204  			return status.Error(codes.PermissionDenied, err.Error())
   205  		}
   206  		if addr == "" {
   207  			// If we don't have an IP we can't rate limit
   208  			return handler(srv, ss)
   209  		}
   210  
   211  		ip, _, err := net.SplitHostPort(addr)
   212  		if err != nil {
   213  			ip = addr
   214  		}
   215  
   216  		if err := s.Inc(ip); err != nil {
   217  			s.log.Error("client reached max subscription allowed",
   218  				logging.Error(err))
   219  			// write error
   220  			return status.Error(codes.ResourceExhausted, "client reached max subscription allowed")
   221  		}
   222  		defer func() {
   223  			s.Dec(ip)
   224  		}()
   225  		return handler(srv, ss)
   226  	}
   227  }
   228  
   229  func getIP(r *http.Request) (string, error) {
   230  	ip := r.Header.Get("X-Real-IP")
   231  	if net.ParseIP(ip) != nil {
   232  		return ip, nil
   233  	}
   234  
   235  	ip = r.Header.Get("X-Forward-For")
   236  	for _, i := range strings.Split(ip, ",") {
   237  		if net.ParseIP(i) != nil {
   238  			return i, nil
   239  		}
   240  	}
   241  
   242  	ip, _, err := net.SplitHostPort(r.RemoteAddr)
   243  	if err != nil {
   244  		return "", err
   245  	}
   246  
   247  	if net.ParseIP(ip) != nil {
   248  		return ip, nil
   249  	}
   250  
   251  	return "", errors.New("no valid ip found")
   252  }