go.fuchsia.dev/infra@v0.0.0-20240507153436-9b593402251b/cmd/gcsproxy/main.go (about) 1 // Copyright 2019 The Fuchsia Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style license that can 3 // found in the LICENSE file. 4 5 package main 6 7 import ( 8 "context" 9 "encoding/json" 10 "flag" 11 "fmt" 12 "io" 13 "log" 14 "math" 15 "net" 16 "net/http" 17 "os" 18 "os/exec" 19 "os/signal" 20 "path" 21 "strings" 22 "sync" 23 "syscall" 24 "time" 25 26 "golang.org/x/oauth2" 27 "golang.org/x/oauth2/google" 28 "golang.org/x/time/rate" 29 ) 30 31 const ( 32 // OAuth2 scope for reading GCS. 33 // See https://cloud.google.com/storage/docs/authentication. 34 gcsReadOnlyScope = "https://www.googleapis.com/auth/devstorage.read_only" 35 36 // We want to allow at most `maxNumRequests` per IP to be serviced every 37 // `refreshWindowMs` milliseconds. 38 // In terms of the token bucket underlying rate.Limiter, this translates to a 39 // new token refreshed every `tokenRefreshRate` with a burst size of 40 // `burstSize`: this allows for a token pool in which we can check Allow() 41 // and Wait() at the desired rates without consuming any tokens reserved for 42 // servicing of other requests. 43 // See https://godoc.org/golang.org/x/time/rate#Limiter for more details. 44 maxNumRequests = 20 45 refreshWindowMs = 200 46 tokenBudgetPerRequest = 2 47 tokenRefreshRate = (refreshWindowMs / (tokenBudgetPerRequest * maxNumRequests)) * time.Millisecond 48 tokenBurstSize = 2 * maxNumRequests 49 50 // Constants for retrying communication with GCS. 51 retryAttempts = 10 52 53 // finishTimeout is the timeout we allow the subprocess to complete in 54 // after we send a SIGTERM. 55 finishTimeout = 10 * time.Second 56 ) 57 58 var ( 59 credentialsFile string 60 port string 61 allowedAddrsFile string 62 // Only a var for testability. 63 // See https://cloud.google.com/storage/docs/request-endpoints. 64 gcsHost = "storage.googleapis.com" 65 // Only a var for testability. 66 retryBackoff = 100 * time.Millisecond 67 ) 68 69 func usage() { 70 fmt.Printf(`gcsproxy [flags] [subcommand] 71 72 Starts a proxy server that forwards requests to GCS with authentication. 73 If positional arguments are provided, they will be run as subprocess and 74 the lifetime of the server will be scoped to the lifetime of that process. 75 `) 76 } 77 78 func init() { 79 flag.Usage = usage 80 flag.StringVar(&credentialsFile, "credentials", "", "path to a credentials file in the Google Credentials File format; if none provided, default application credentials will be used.") 81 flag.StringVar(&port, "port", "", "port at which the server should listen.") 82 flag.StringVar(&allowedAddrsFile, "allowed", "", "a flat JSON list of remote addresses allowed to make requests of the proxy server; if not provided, all addresses will be allowed") 83 } 84 85 func main() { 86 flag.Parse() 87 88 log.SetPrefix("gcsproxy: ") 89 log.SetFlags(log.Lmsgprefix | log.Ltime | log.Lmicroseconds | log.Lshortfile) 90 91 // For a graceful teardown in the event of a canceling signal. 92 ctx, cancel := context.WithCancel(context.Background()) 93 signals := make(chan os.Signal) 94 defer close(signals) 95 signal.Notify(signals, syscall.SIGTERM, syscall.SIGINT) 96 97 go func() { 98 select { 99 case <-signals: 100 cancel() 101 } 102 }() 103 104 if err := execute(ctx, credentialsFile, port, allowedAddrsFile, flag.Args()); err != nil { 105 log.Fatal(err) 106 } 107 } 108 109 func runSubcommand(ctx context.Context, subCmd []string) error { 110 cmd := exec.Command(subCmd[0], subCmd[1:]...) 111 cmd.Stdout = os.Stdout 112 cmd.Stderr = os.Stderr 113 114 // Set a process group ID so we can kill the entire group, 115 // meaning the process and any of its children. 116 cmd.SysProcAttr = &syscall.SysProcAttr{Setpgid: true} 117 118 // Spin off handler to exit subprocesses cleanly via SIGTERM. 119 processDone := make(chan struct{}) 120 var processMu sync.Mutex 121 go func() { 122 select { 123 case <-processDone: 124 case <-ctx.Done(): 125 // We need to check if the process is nil because it won't exist if 126 // it has been SIGKILL'd already. 127 processMu.Lock() 128 defer processMu.Unlock() 129 if cmd.Process != nil { 130 if err := cmd.Process.Signal(syscall.SIGTERM); err != nil { 131 log.Printf("exited cmd with error %s", err) 132 } 133 // If the subprocess doesn't complete within the finishTimeout, 134 // send a SIGKILL to force it to exit. 135 go func() { 136 select { 137 case <-processDone: 138 case <-time.After(finishTimeout): 139 log.Printf("killing process %d", cmd.Process.Pid) 140 // Negating the process ID means interpret it as a process group ID, so 141 // we kill the subprocess and all of its children. 142 if err := syscall.Kill(-cmd.Process.Pid, syscall.SIGKILL); err != nil { 143 log.Printf("killed cmd with error %s", err) 144 } 145 } 146 }() 147 } 148 } 149 }() 150 151 // Ensure that the context still exists before running the subprocess. 152 if ctx.Err() != nil { 153 log.Print("context exited before starting subprocess") 154 return ctx.Err() 155 } 156 157 // We need to make this a critical section because running Start changes 158 // cmd.Process, which we attempt to access in the goroutine above. Not locking 159 // causes a data race. 160 processMu.Lock() 161 log.Printf("starting: %s", subCmd) 162 err := cmd.Start() 163 processMu.Unlock() 164 if err != nil { 165 close(processDone) 166 return err 167 } 168 // Since we wait for the command to complete even if we send a SIGTERM when the 169 // context is canceled, it is up to the underlying command to exit with the 170 // proper exit code after handling a SIGTERM. 171 err = cmd.Wait() 172 close(processDone) 173 if err != nil && ctx.Err() != nil { 174 // If the subprocess was terminated early and completed without terminating 175 // its child processes, kill any remaining processes in the group. 176 log.Printf("killing process %d", cmd.Process.Pid) 177 syscall.Kill(-cmd.Process.Pid, syscall.SIGKILL) 178 return ctx.Err() 179 } 180 return err 181 } 182 183 // Execute starts the proxy server. 184 func execute(ctx context.Context, credFile, port, addrsFile string, subCmd []string) error { 185 if port == "" { 186 return fmt.Errorf("-port is required") 187 } 188 189 client, err := httpClient(ctx, credFile) 190 if err != nil { 191 return err 192 } 193 194 var allowedAddrs []string 195 if allowedAddrsFile != "" { 196 b, err := os.ReadFile(allowedAddrsFile) 197 if err != nil { 198 return err 199 } 200 if err := json.Unmarshal(b, &allowedAddrs); err != nil { 201 return err 202 } 203 } 204 205 limiters := new(sync.Map) 206 for _, addr := range allowedAddrs { 207 limiters.Store(addr, newLimiter()) 208 } 209 210 outdir := os.Getenv("FUCHSIA_TEST_OUTDIR") 211 if outdir == "" { 212 log.Printf("GCS Proxy could not find FUCHSIA_TEST_OUTDIR, using %s instead", os.TempDir()) 213 outdir = os.TempDir() 214 } 215 metricsFile := path.Join(outdir, "gcsproxy_metrics") 216 log.Printf("GCS Proxy will log metrics to %s", metricsFile) 217 218 redirect := &redirectHandler{ 219 client: client, 220 restrictAddrs: allowedAddrs != nil, 221 limiters: limiters, 222 metricsFile: metricsFile, 223 } 224 225 mux := http.NewServeMux() 226 mux.Handle("/", redirect) 227 ln, err := net.Listen("tcp", fmt.Sprintf(":%s", port)) 228 if err != nil { 229 return err 230 } 231 s := http.Server{Handler: mux} 232 233 wg := sync.WaitGroup{} 234 wg.Add(1) 235 go func() { 236 defer wg.Done() 237 log.Printf("starting a GCS proxy server at %s", ln.Addr().String()) 238 if err := s.Serve(ln); err != http.ErrServerClosed { 239 log.Fatal(err) 240 } 241 }() 242 defer wg.Wait() 243 244 defer func() { 245 log.Printf("Shutting down GCS proxy server at %s", ln.Addr().String()) 246 if err := s.Shutdown(ctx); err != nil { 247 log.Fatal(err) 248 } 249 }() 250 251 if len(subCmd) == 0 { 252 log.Printf("Press Ctrl-C to abort") 253 <-ctx.Done() 254 } 255 return runSubcommand(ctx, subCmd) 256 } 257 258 func newLimiter() *rate.Limiter { 259 limit := rate.Every(tokenRefreshRate) 260 return rate.NewLimiter(limit, tokenBurstSize) 261 } 262 263 // sleep sleeps until the context is done or d time elapsed. 264 // 265 // Returns true if the context was done. 266 func sleep(ctx context.Context, d time.Duration) bool { 267 timer := time.NewTimer(d) 268 defer timer.Stop() 269 270 select { 271 case <-ctx.Done(): 272 return true 273 case <-timer.C: 274 return false 275 } 276 } 277 278 // RedirectHandler is a simple handler that redirects requests to GCS. 279 type redirectHandler struct { 280 client *http.Client 281 // Whether to only serve to addresses present in the limiter map. 282 restrictAddrs bool 283 // limiters is a map of string: *rate.Limiter. 284 limiters *sync.Map 285 // metricsFile is a path to a file to write metrics to. 286 metricsFile string 287 } 288 289 func shouldRetry(statusCode int) bool { 290 // We suspect we're seeing flaky NotFound errors, so retry in that case. See fxbug.dev/65035. 291 return (statusCode == http.StatusNotFound || 292 // Suggested by https://cloud.google.com/storage/docs/request-rate#ramp-up 293 statusCode == http.StatusTooManyRequests || 294 statusCode >= http.StatusInternalServerError) 295 } 296 297 func (h *redirectHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) { 298 if req == nil { 299 http.Error(w, "cannot handle nil request", http.StatusInternalServerError) // = 501 300 return 301 } 302 ctx := context.Background() 303 reqReceived := time.Now() 304 305 limiter, ok := h.getLimiter(req.RemoteAddr) 306 if !ok { 307 http.Error( 308 w, 309 fmt.Sprintf("address %q is not authorized to make requests", req.RemoteAddr), 310 http.StatusForbidden, // = 403 311 ) 312 return 313 } 314 315 startedLimiterWait := time.Now() 316 if !limiter.Allow() { 317 if err := limiter.Wait(ctx); err != nil { 318 http.Error(w, fmt.Sprintf("rate-limiting error: %s", err), http.StatusInternalServerError) 319 } 320 } 321 limiterLatency := time.Since(startedLimiterWait).Milliseconds() 322 323 req.Host = gcsHost 324 req.URL.Host = gcsHost 325 req.URL.Scheme = "https" 326 // It is an error to set this field in an HTTP client request 327 // See https://golang.org/pkg/net/http/#Request. 328 req.RequestURI = "" 329 330 var resp *http.Response 331 var err error 332 loggedReq := false 333 var gcsLatency int64 334 var attempt int 335 var backoffLatency int64 336 for attempt = 0; attempt < retryAttempts; attempt++ { 337 startedGCSReq := time.Now() 338 resp, err = h.client.Do(req) 339 if err == nil { 340 gcsLatency += time.Since(startedGCSReq).Milliseconds() 341 } 342 if err == nil && !shouldRetry(resp.StatusCode) || ctx.Err() != nil || attempt+1 == retryAttempts { 343 break 344 } 345 if !loggedReq { 346 var msg string 347 if err != nil { 348 msg = fmt.Sprintf("unexpected error: %s", err) 349 } else if resp != nil { 350 msg = fmt.Sprintf("unexpected response: %s", resp.Status) 351 } else { 352 msg = "both error and response are nil" 353 } 354 log.Printf("%s. Retrying request. URL: %s Header: %s", msg, req.URL, req.Header) 355 loggedReq = true 356 } 357 backoff := retryBackoff * time.Duration(math.Pow(1.5, float64(attempt))) 358 backoffLatency += backoff.Milliseconds() 359 if sleep(ctx, backoff) { 360 break 361 } 362 } 363 if err != nil { 364 http.Error(w, err.Error(), http.StatusInternalServerError) 365 return 366 } else if resp == nil { 367 http.Error(w, "received a nil response", http.StatusInternalServerError) 368 } 369 370 startResponseTime := time.Now() 371 for k, v := range resp.Header { 372 for _, s := range v { 373 w.Header().Add(k, s) 374 } 375 } 376 w.WriteHeader(resp.StatusCode) 377 378 blobSize, err := io.Copy(w, resp.Body) 379 if err != nil { 380 http.Error(w, err.Error(), http.StatusInternalServerError) 381 return 382 } 383 metrics := fmt.Sprintf( 384 "%s, %s, %d, %d, %d, %d, %d, %d, %d\n", 385 reqReceived.String(), 386 req.URL.String(), 387 blobSize, 388 time.Since(reqReceived).Milliseconds(), 389 limiterLatency, 390 gcsLatency, 391 backoffLatency, 392 time.Since(startResponseTime).Milliseconds(), 393 attempt+1, 394 ) 395 f, err := os.OpenFile(h.metricsFile, os.O_APPEND|os.O_WRONLY|os.O_CREATE, 0o755) 396 if err != nil { 397 log.Printf("failed to open metrics file: %s", err) 398 return 399 } 400 defer f.Close() 401 f.WriteString(metrics) 402 } 403 404 // getLimiter returns the limiter associated with a given address and whether 405 // the address is allowed to make requests. 406 func (h *redirectHandler) getLimiter(addr string) (*rate.Limiter, bool) { 407 var limiter any 408 var ok bool 409 if limiter, ok = h.limiters.Load(addr); !ok { 410 // While the full address may have been in the `-allowed` list, its 411 // hostname may be; in that case, dynamically add a limiter for that address. 412 hostname := strings.Split(addr, ":")[0] 413 _, hostOK := h.limiters.Load(hostname) 414 415 if !h.restrictAddrs || hostOK { 416 limiter = newLimiter() 417 h.limiters.Store(addr, limiter) 418 ok = true 419 } 420 } 421 if ok { 422 return limiter.(*rate.Limiter), ok 423 } 424 return nil, false 425 } 426 427 // Returns an HTTP client with the credentials to read from GCS. If no 428 // credential file is supplied, then the default application credentials 429 // will be used. 430 func httpClient(ctx context.Context, credFile string) (*http.Client, error) { 431 var creds *google.Credentials 432 var err error 433 if credFile == "" { 434 creds, err = google.FindDefaultCredentials(ctx, gcsReadOnlyScope) 435 if err != nil { 436 return nil, fmt.Errorf("failed to find default credentials: %w", err) 437 } 438 } else { 439 contents, err := os.ReadFile(credFile) 440 if err != nil { 441 return nil, err 442 } 443 creds, err = google.CredentialsFromJSON(ctx, contents, gcsReadOnlyScope) 444 if err != nil { 445 return nil, fmt.Errorf("failed to derive the derive credentials from supplied file: %w", err) 446 } 447 } 448 return oauth2.NewClient(ctx, creds.TokenSource), nil 449 }