code.vegaprotocol.io/vega@v0.79.0/datanode/gateway/middleware.go (about) 1 // Copyright (C) 2023 Gobalsky Labs Limited 2 // 3 // This program is free software: you can redistribute it and/or modify 4 // it under the terms of the GNU Affero General Public License as 5 // published by the Free Software Foundation, either version 3 of the 6 // License, or (at your option) any later version. 7 // 8 // This program is distributed in the hope that it will be useful, 9 // but WITHOUT ANY WARRANTY; without even the implied warranty of 10 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 11 // GNU Affero General Public License for more details. 12 // 13 // You should have received a copy of the GNU Affero General Public License 14 // along with this program. If not, see <http://www.gnu.org/licenses/>. 15 16 package gateway 17 18 import ( 19 "context" 20 "errors" 21 "fmt" 22 "net" 23 "net/http" 24 "strings" 25 "sync" 26 "time" 27 28 "code.vegaprotocol.io/vega/datanode/contextutil" 29 "code.vegaprotocol.io/vega/datanode/metrics" 30 vfmt "code.vegaprotocol.io/vega/libs/fmt" 31 vhttp "code.vegaprotocol.io/vega/libs/http" 32 "code.vegaprotocol.io/vega/logging" 33 34 "google.golang.org/grpc" 35 "google.golang.org/grpc/codes" 36 "google.golang.org/grpc/status" 37 ) 38 39 var ErrMaxSubscriptionReached = func(ip string, max uint32) error { 40 return fmt.Errorf("max subscriptions count (%v) reached for ip (%s)", max, ip) 41 } 42 43 // RemoteAddrMiddleware is a middleware adding to the current request context the 44 // address of the caller. 
func RemoteAddrMiddleware(log *logging.Logger, next http.Handler) http.Handler {
	serve := func(w http.ResponseWriter, req *http.Request) {
		addr, err := vhttp.RemoteAddr(req)
		if err == nil {
			annotated := contextutil.WithRemoteIPAddr(req.Context(), addr)
			next.ServeHTTP(w, req.WithContext(annotated))
			return
		}

		// Resolution failed: record the raw inputs (the forwarded header is
		// escaped before logging since it comes straight from the request) and
		// continue with the request's original context.
		log.Debug("Failed to get remote address in middleware",
			logging.String("remote-addr", req.RemoteAddr),
			logging.String("x-forwarded-for", vfmt.Escape(req.Header.Get("X-Forwarded-For"))),
		)
		next.ServeHTTP(w, req)
	}
	return http.HandlerFunc(serve)
}

// MetricCollectionMiddleware reports every request to the metrics package:
// its method, raw request URI and the wall-clock time, in seconds, spent in
// the downstream handler.
func MetricCollectionMiddleware(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
		began := time.Now()
		next.ServeHTTP(w, req)
		elapsed := time.Since(began)

		metrics.APIRequestAndTimeREST(req.Method, req.RequestURI, elapsed.Seconds())
	})
}

// Chain wraps f with the middlewares in m so that m[0] is the outermost layer
// and sees each request first; with no middlewares, f is returned unchanged.
func Chain(f http.Handler, m ...func(http.Handler) http.Handler) http.Handler {
	// With no middleware left, the innermost handler is f itself.
	if len(m) == 0 {
		return f
	}
	// m[0] wraps the chain built from the remaining middlewares. The tail is
	// m[1:], which is bounded by len. Slicing up to cap(m) would also apply the
	// stale elements past len whenever a caller passes a sub-slice such as
	// mws[:n]...
	return m[0](Chain(f, m[1:]...))
}

// WithAddHeadersMiddleware wraps the response writer in an
// InjectableResponseWriter and stores it in the request context. Downstream
// code can fetch it with InjectableWriterFromContext and register extra
// headers via SetHeaders. Writers that do not implement http.Hijacker are
// passed through untouched.
func WithAddHeadersMiddleware(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		hijacker, ok := w.(http.Hijacker)
		if !ok {
			next.ServeHTTP(w, r)
			return
		}
		iw := &InjectableResponseWriter{
			ResponseWriter: w,
			Hijacker:       hijacker,
		}
		ctx := context.WithValue(r.Context(), injectableWriterKey{}, iw)
		next.ServeHTTP(iw, r.WithContext(ctx))
	})
}

// InjectableResponseWriter is an http.ResponseWriter that also exposes the
// wrapped writer's http.Hijacker. On Write, it adds headers that were
// registered via SetHeaders.
type InjectableResponseWriter struct {
	http.ResponseWriter
	http.Hijacker
	// headers are the extra headers registered via SetHeaders. Only the first
	// value of each key is added to the response.
	headers http.Header
}

// injectableWriterKey is the context key under which WithAddHeadersMiddleware
// stores the *InjectableResponseWriter.
type injectableWriterKey struct{}

// InjectableWriterFromContext returns the *InjectableResponseWriter stored in
// ctx by WithAddHeadersMiddleware. It reports false when ctx is nil or holds
// no such writer.
func InjectableWriterFromContext(ctx context.Context) (*InjectableResponseWriter, bool) {
	if ctx == nil {
		return nil, false
	}
	// A checked assertion: an unexpected value type yields (nil, false)
	// instead of a panic.
	iw, ok := ctx.Value(injectableWriterKey{}).(*InjectableResponseWriter)
	return iw, ok
}

// Write adds the first value of every header registered via SetHeaders to the
// response header map, then writes data to the wrapped ResponseWriter.
// net/http ignores header changes made after the first Write or an explicit
// WriteHeader, so headers must be set before the response body starts.
func (i *InjectableResponseWriter) Write(data []byte) (int, error) {
	for k, v := range i.headers {
		if len(v) > 0 {
			i.ResponseWriter.Header().Add(k, v[0])
		}
	}
	return i.ResponseWriter.Write(data)
}

// SetHeaders replaces the set of headers that Write adds to the response.
func (i *InjectableResponseWriter) SetHeaders(headers http.Header) {
	i.headers = headers
}

// SubscriptionRateLimiter caps the number of concurrent subscriptions held by
// each client IP. Per-IP counters live in m and are guarded by mu.
type SubscriptionRateLimiter struct {
	log *logging.Logger
	m   map[string]uint32
	mu  sync.Mutex

	// MaxSubscriptions is the per-IP ceiling enforced by Inc.
	MaxSubscriptions uint32
}

// NewSubscriptionRateLimiter returns a limiter allowing at most
// maxSubscriptions concurrent subscriptions per IP.
func NewSubscriptionRateLimiter(
	log *logging.Logger,
	maxSubscriptions uint32,
) *SubscriptionRateLimiter {
	return &SubscriptionRateLimiter{
		log:              log,
		MaxSubscriptions: maxSubscriptions,
		m:                map[string]uint32{},
	}
}

// Inc reserves one subscription slot for ip. If ip already holds
// MaxSubscriptions slots or more, it returns ErrMaxSubscriptionReached and
// reserves nothing.
func (s *SubscriptionRateLimiter) Inc(ip string) error {
	s.mu.Lock()
	defer s.mu.Unlock()
	cnt := s.m[ip]
	// >= rather than == so the check still holds if the count ever exceeds the
	// limit (e.g. MaxSubscriptions lowered at runtime).
	if cnt >= s.MaxSubscriptions {
		return ErrMaxSubscriptionReached(ip, s.MaxSubscriptions)
	}
	s.m[ip] = cnt + 1
	return nil
}

// Dec releases a slot previously reserved by Inc. An entry is removed once it
// drops to zero, so the map only tracks IPs with live subscriptions. Calling
// Dec for an IP with no reserved slot is a no-op; the counter is not wrapped
// around to MaxUint32.
func (s *SubscriptionRateLimiter) Dec(ip string) {
	s.mu.Lock()
	defer s.mu.Unlock()
	cnt, ok := s.m[ip]
	if !ok {
		return
	}
	if cnt <= 1 {
		delete(s.m, ip)
		return
	}
	s.m[ip] = cnt - 1
}

// WithSubscriptionRateLimiter enforces the per-IP limit on HTTP requests whose
// ResponseWriter implements http.Hijacker, which are treated as subscriptions.
// A rejected request gets 429 Too Many Requests with the error text as body.
// The slot is released when the downstream handler returns.
// NOTE(review): net/http's HTTP/1.x ResponseWriter implements http.Hijacker for
// every request, so plain requests may be counted too. Confirm this is intended.
func (s *SubscriptionRateLimiter) WithSubscriptionRateLimiter(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// is that a subscription?
		if _, ok := w.(http.Hijacker); !ok {
			next.ServeHTTP(w, r)
			return
		}

		ip, err := getIP(r)
		if err != nil {
			// Without a client IP there is nothing to count against.
			s.log.Debug("couldn't get client ip", logging.Error(err))
			next.ServeHTTP(w, r)
			return
		}

		if err := s.Inc(ip); err != nil {
			s.log.Error("client reached max subscription allowed",
				logging.Error(err))
			w.WriteHeader(http.StatusTooManyRequests)
			// Best effort: the request is rejected either way, so a failed
			// body write has nothing to recover.
			_, _ = w.Write([]byte(err.Error()))
			return
		}
		defer s.Dec(ip)

		next.ServeHTTP(w, r)
	})
}

// ipGetter resolves the caller address for a gRPC call to method. An empty
// address means the call cannot be rate limited.
type ipGetter func(ctx context.Context, method string, log *logging.Logger) (string, error)

// WithGrpcInterceptor returns a stream interceptor that applies the limiter to
// gRPC streams. It behaves as follows:
//   - a resolution error is returned as PermissionDenied;
//   - an empty address passes through unlimited;
//   - the port, if any, is stripped from the address;
//   - a stream beyond the limit is rejected with ResourceExhausted.
//
// The slot is released when the stream handler returns.
func (s *SubscriptionRateLimiter) WithGrpcInterceptor(ipGetterFunc ipGetter) grpc.StreamServerInterceptor {
	return func(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
		addr, err := ipGetterFunc(ss.Context(), info.FullMethod, s.log)
		if err != nil {
			return status.Error(codes.PermissionDenied, err.Error())
		}
		if addr == "" {
			// If we don't have an IP we can't rate limit.
			return handler(srv, ss)
		}

		// Count per host rather than per host:port. Use the raw value when it
		// carries no port.
		ip, _, err := net.SplitHostPort(addr)
		if err != nil {
			ip = addr
		}

		if err := s.Inc(ip); err != nil {
			s.log.Error("client reached max subscription allowed",
				logging.Error(err))
			return status.Error(codes.ResourceExhausted, "client reached max subscription allowed")
		}
		defer s.Dec(ip)

		return handler(srv, ss)
	}
}

// getIP resolves the client IP of an HTTP request. It tries these sources in
// order and returns the first valid one:
//   - the X-Real-IP header;
//   - the first parseable entry of X-Forwarded-For;
//   - the host part of r.RemoteAddr.
//
// NOTE(review): both headers are client-controlled unless a trusted proxy
// overwrites them, so a client may be able to choose the IP it is limited under.
func getIP(r *http.Request) (string, error) {
	if ip := strings.TrimSpace(r.Header.Get("X-Real-IP")); net.ParseIP(ip) != nil {
		return ip, nil
	}

	// The standard header is X-Forwarded-For (not "X-Forward-For"). Its list
	// entries are conventionally separated by ", ", so each entry is trimmed
	// before parsing.
	for _, entry := range strings.Split(r.Header.Get("X-Forwarded-For"), ",") {
		if ip := strings.TrimSpace(entry); net.ParseIP(ip) != nil {
			return ip, nil
		}
	}

	host, _, err := net.SplitHostPort(r.RemoteAddr)
	if err != nil {
		return "", err
	}
	if net.ParseIP(host) != nil {
		return host, nil
	}

	return "", errors.New("no valid ip found")
}