go.fuchsia.dev/infra@v0.0.0-20240507153436-9b593402251b/cmd/gcsproxy/main.go (about)

     1  // Copyright 2019 The Fuchsia Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style license that can
     3  // found in the LICENSE file.
     4  
     5  package main
     6  
     7  import (
     8  	"context"
     9  	"encoding/json"
    10  	"flag"
    11  	"fmt"
    12  	"io"
    13  	"log"
    14  	"math"
    15  	"net"
    16  	"net/http"
    17  	"os"
    18  	"os/exec"
    19  	"os/signal"
    20  	"path"
    21  	"strings"
    22  	"sync"
    23  	"syscall"
    24  	"time"
    25  
    26  	"golang.org/x/oauth2"
    27  	"golang.org/x/oauth2/google"
    28  	"golang.org/x/time/rate"
    29  )
    30  
    31  const (
    32  	// OAuth2 scope for reading GCS.
    33  	// See https://cloud.google.com/storage/docs/authentication.
    34  	gcsReadOnlyScope = "https://www.googleapis.com/auth/devstorage.read_only"
    35  
    36  	// We want to allow at most `maxNumRequests` per IP to be serviced every
    37  	// `refreshWindowMs` milliseconds.
    38  	// In terms of the token bucket underlying rate.Limiter, this translates to a
    39  	// new token refreshed every `tokenRefreshRate` with a burst size of
    40  	// `burstSize`: this allows for a token pool in which we can check Allow()
    41  	// and Wait() at the desired rates without consuming any tokens reserved for
    42  	// servicing of other requests.
    43  	// See https://godoc.org/golang.org/x/time/rate#Limiter for more details.
    44  	maxNumRequests        = 20
    45  	refreshWindowMs       = 200
    46  	tokenBudgetPerRequest = 2
    47  	tokenRefreshRate      = (refreshWindowMs / (tokenBudgetPerRequest * maxNumRequests)) * time.Millisecond
    48  	tokenBurstSize        = 2 * maxNumRequests
    49  
    50  	// Constants for retrying communication with GCS.
    51  	retryAttempts = 10
    52  
    53  	// finishTimeout is the timeout we allow the subprocess to complete in
    54  	// after we send a SIGTERM.
    55  	finishTimeout = 10 * time.Second
    56  )
    57  
    58  var (
    59  	credentialsFile  string
    60  	port             string
    61  	allowedAddrsFile string
    62  	// Only a var for testability.
    63  	// See https://cloud.google.com/storage/docs/request-endpoints.
    64  	gcsHost = "storage.googleapis.com"
    65  	// Only a var for testability.
    66  	retryBackoff = 100 * time.Millisecond
    67  )
    68  
    69  func usage() {
    70  	fmt.Printf(`gcsproxy [flags] [subcommand]
    71  
    72  Starts a proxy server that forwards requests to GCS with authentication.
    73  If positional arguments are provided, they will be run as subprocess and
    74  the lifetime of the server will be scoped to the lifetime of that process.
    75  `)
    76  }
    77  
    78  func init() {
    79  	flag.Usage = usage
    80  	flag.StringVar(&credentialsFile, "credentials", "", "path to a credentials file in the Google Credentials File format; if none provided, default application credentials will be used.")
    81  	flag.StringVar(&port, "port", "", "port at which the server should listen.")
    82  	flag.StringVar(&allowedAddrsFile, "allowed", "", "a flat JSON list of remote addresses allowed to make requests of the proxy server; if not provided, all addresses will be allowed")
    83  }
    84  
    85  func main() {
    86  	flag.Parse()
    87  
    88  	log.SetPrefix("gcsproxy: ")
    89  	log.SetFlags(log.Lmsgprefix | log.Ltime | log.Lmicroseconds | log.Lshortfile)
    90  
    91  	// For a graceful teardown in the event of a canceling signal.
    92  	ctx, cancel := context.WithCancel(context.Background())
    93  	signals := make(chan os.Signal)
    94  	defer close(signals)
    95  	signal.Notify(signals, syscall.SIGTERM, syscall.SIGINT)
    96  
    97  	go func() {
    98  		select {
    99  		case <-signals:
   100  			cancel()
   101  		}
   102  	}()
   103  
   104  	if err := execute(ctx, credentialsFile, port, allowedAddrsFile, flag.Args()); err != nil {
   105  		log.Fatal(err)
   106  	}
   107  }
   108  
   109  func runSubcommand(ctx context.Context, subCmd []string) error {
   110  	cmd := exec.Command(subCmd[0], subCmd[1:]...)
   111  	cmd.Stdout = os.Stdout
   112  	cmd.Stderr = os.Stderr
   113  
   114  	// Set a process group ID so we can kill the entire group,
   115  	// meaning the process and any of its children.
   116  	cmd.SysProcAttr = &syscall.SysProcAttr{Setpgid: true}
   117  
   118  	// Spin off handler to exit subprocesses cleanly via SIGTERM.
   119  	processDone := make(chan struct{})
   120  	var processMu sync.Mutex
   121  	go func() {
   122  		select {
   123  		case <-processDone:
   124  		case <-ctx.Done():
   125  			// We need to check if the process is nil because it won't exist if
   126  			// it has been SIGKILL'd already.
   127  			processMu.Lock()
   128  			defer processMu.Unlock()
   129  			if cmd.Process != nil {
   130  				if err := cmd.Process.Signal(syscall.SIGTERM); err != nil {
   131  					log.Printf("exited cmd with error %s", err)
   132  				}
   133  				// If the subprocess doesn't complete within the finishTimeout,
   134  				// send a SIGKILL to force it to exit.
   135  				go func() {
   136  					select {
   137  					case <-processDone:
   138  					case <-time.After(finishTimeout):
   139  						log.Printf("killing process %d", cmd.Process.Pid)
   140  						// Negating the process ID means interpret it as a process group ID, so
   141  						// we kill the subprocess and all of its children.
   142  						if err := syscall.Kill(-cmd.Process.Pid, syscall.SIGKILL); err != nil {
   143  							log.Printf("killed cmd with error %s", err)
   144  						}
   145  					}
   146  				}()
   147  			}
   148  		}
   149  	}()
   150  
   151  	// Ensure that the context still exists before running the subprocess.
   152  	if ctx.Err() != nil {
   153  		log.Print("context exited before starting subprocess")
   154  		return ctx.Err()
   155  	}
   156  
   157  	// We need to make this a critical section because running Start changes
   158  	// cmd.Process, which we attempt to access in the goroutine above. Not locking
   159  	// causes a data race.
   160  	processMu.Lock()
   161  	log.Printf("starting: %s", subCmd)
   162  	err := cmd.Start()
   163  	processMu.Unlock()
   164  	if err != nil {
   165  		close(processDone)
   166  		return err
   167  	}
   168  	// Since we wait for the command to complete even if we send a SIGTERM when the
   169  	// context is canceled, it is up to the underlying command to exit with the
   170  	// proper exit code after handling a SIGTERM.
   171  	err = cmd.Wait()
   172  	close(processDone)
   173  	if err != nil && ctx.Err() != nil {
   174  		// If the subprocess was terminated early and completed without terminating
   175  		// its child processes, kill any remaining processes in the group.
   176  		log.Printf("killing process %d", cmd.Process.Pid)
   177  		syscall.Kill(-cmd.Process.Pid, syscall.SIGKILL)
   178  		return ctx.Err()
   179  	}
   180  	return err
   181  }
   182  
   183  // Execute starts the proxy server.
   184  func execute(ctx context.Context, credFile, port, addrsFile string, subCmd []string) error {
   185  	if port == "" {
   186  		return fmt.Errorf("-port is required")
   187  	}
   188  
   189  	client, err := httpClient(ctx, credFile)
   190  	if err != nil {
   191  		return err
   192  	}
   193  
   194  	var allowedAddrs []string
   195  	if allowedAddrsFile != "" {
   196  		b, err := os.ReadFile(allowedAddrsFile)
   197  		if err != nil {
   198  			return err
   199  		}
   200  		if err := json.Unmarshal(b, &allowedAddrs); err != nil {
   201  			return err
   202  		}
   203  	}
   204  
   205  	limiters := new(sync.Map)
   206  	for _, addr := range allowedAddrs {
   207  		limiters.Store(addr, newLimiter())
   208  	}
   209  
   210  	outdir := os.Getenv("FUCHSIA_TEST_OUTDIR")
   211  	if outdir == "" {
   212  		log.Printf("GCS Proxy could not find FUCHSIA_TEST_OUTDIR, using %s instead", os.TempDir())
   213  		outdir = os.TempDir()
   214  	}
   215  	metricsFile := path.Join(outdir, "gcsproxy_metrics")
   216  	log.Printf("GCS Proxy will log metrics to %s", metricsFile)
   217  
   218  	redirect := &redirectHandler{
   219  		client:        client,
   220  		restrictAddrs: allowedAddrs != nil,
   221  		limiters:      limiters,
   222  		metricsFile:   metricsFile,
   223  	}
   224  
   225  	mux := http.NewServeMux()
   226  	mux.Handle("/", redirect)
   227  	ln, err := net.Listen("tcp", fmt.Sprintf(":%s", port))
   228  	if err != nil {
   229  		return err
   230  	}
   231  	s := http.Server{Handler: mux}
   232  
   233  	wg := sync.WaitGroup{}
   234  	wg.Add(1)
   235  	go func() {
   236  		defer wg.Done()
   237  		log.Printf("starting a GCS proxy server at %s", ln.Addr().String())
   238  		if err := s.Serve(ln); err != http.ErrServerClosed {
   239  			log.Fatal(err)
   240  		}
   241  	}()
   242  	defer wg.Wait()
   243  
   244  	defer func() {
   245  		log.Printf("Shutting down GCS proxy server at %s", ln.Addr().String())
   246  		if err := s.Shutdown(ctx); err != nil {
   247  			log.Fatal(err)
   248  		}
   249  	}()
   250  
   251  	if len(subCmd) == 0 {
   252  		log.Printf("Press Ctrl-C to abort")
   253  		<-ctx.Done()
   254  	}
   255  	return runSubcommand(ctx, subCmd)
   256  }
   257  
   258  func newLimiter() *rate.Limiter {
   259  	limit := rate.Every(tokenRefreshRate)
   260  	return rate.NewLimiter(limit, tokenBurstSize)
   261  }
   262  
   263  // sleep sleeps until the context is done or d time elapsed.
   264  //
   265  // Returns true if the context was done.
   266  func sleep(ctx context.Context, d time.Duration) bool {
   267  	timer := time.NewTimer(d)
   268  	defer timer.Stop()
   269  
   270  	select {
   271  	case <-ctx.Done():
   272  		return true
   273  	case <-timer.C:
   274  		return false
   275  	}
   276  }
   277  
   278  // RedirectHandler is a simple handler that redirects requests to GCS.
   279  type redirectHandler struct {
   280  	client *http.Client
   281  	// Whether to only serve to addresses present in the limiter map.
   282  	restrictAddrs bool
   283  	// limiters is a map of string: *rate.Limiter.
   284  	limiters *sync.Map
   285  	// metricsFile is a path to a file to write metrics to.
   286  	metricsFile string
   287  }
   288  
   289  func shouldRetry(statusCode int) bool {
   290  	// We suspect we're seeing flaky NotFound errors, so retry in that case. See fxbug.dev/65035.
   291  	return (statusCode == http.StatusNotFound ||
   292  		// Suggested by https://cloud.google.com/storage/docs/request-rate#ramp-up
   293  		statusCode == http.StatusTooManyRequests ||
   294  		statusCode >= http.StatusInternalServerError)
   295  }
   296  
   297  func (h *redirectHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
   298  	if req == nil {
   299  		http.Error(w, "cannot handle nil request", http.StatusInternalServerError) // = 501
   300  		return
   301  	}
   302  	ctx := context.Background()
   303  	reqReceived := time.Now()
   304  
   305  	limiter, ok := h.getLimiter(req.RemoteAddr)
   306  	if !ok {
   307  		http.Error(
   308  			w,
   309  			fmt.Sprintf("address %q is not authorized to make requests", req.RemoteAddr),
   310  			http.StatusForbidden, // = 403
   311  		)
   312  		return
   313  	}
   314  
   315  	startedLimiterWait := time.Now()
   316  	if !limiter.Allow() {
   317  		if err := limiter.Wait(ctx); err != nil {
   318  			http.Error(w, fmt.Sprintf("rate-limiting error: %s", err), http.StatusInternalServerError)
   319  		}
   320  	}
   321  	limiterLatency := time.Since(startedLimiterWait).Milliseconds()
   322  
   323  	req.Host = gcsHost
   324  	req.URL.Host = gcsHost
   325  	req.URL.Scheme = "https"
   326  	// It is an error to set this field in an HTTP client request
   327  	// See https://golang.org/pkg/net/http/#Request.
   328  	req.RequestURI = ""
   329  
   330  	var resp *http.Response
   331  	var err error
   332  	loggedReq := false
   333  	var gcsLatency int64
   334  	var attempt int
   335  	var backoffLatency int64
   336  	for attempt = 0; attempt < retryAttempts; attempt++ {
   337  		startedGCSReq := time.Now()
   338  		resp, err = h.client.Do(req)
   339  		if err == nil {
   340  			gcsLatency += time.Since(startedGCSReq).Milliseconds()
   341  		}
   342  		if err == nil && !shouldRetry(resp.StatusCode) || ctx.Err() != nil || attempt+1 == retryAttempts {
   343  			break
   344  		}
   345  		if !loggedReq {
   346  			var msg string
   347  			if err != nil {
   348  				msg = fmt.Sprintf("unexpected error: %s", err)
   349  			} else if resp != nil {
   350  				msg = fmt.Sprintf("unexpected response: %s", resp.Status)
   351  			} else {
   352  				msg = "both error and response are nil"
   353  			}
   354  			log.Printf("%s. Retrying request. URL: %s Header: %s", msg, req.URL, req.Header)
   355  			loggedReq = true
   356  		}
   357  		backoff := retryBackoff * time.Duration(math.Pow(1.5, float64(attempt)))
   358  		backoffLatency += backoff.Milliseconds()
   359  		if sleep(ctx, backoff) {
   360  			break
   361  		}
   362  	}
   363  	if err != nil {
   364  		http.Error(w, err.Error(), http.StatusInternalServerError)
   365  		return
   366  	} else if resp == nil {
   367  		http.Error(w, "received a nil response", http.StatusInternalServerError)
   368  	}
   369  
   370  	startResponseTime := time.Now()
   371  	for k, v := range resp.Header {
   372  		for _, s := range v {
   373  			w.Header().Add(k, s)
   374  		}
   375  	}
   376  	w.WriteHeader(resp.StatusCode)
   377  
   378  	blobSize, err := io.Copy(w, resp.Body)
   379  	if err != nil {
   380  		http.Error(w, err.Error(), http.StatusInternalServerError)
   381  		return
   382  	}
   383  	metrics := fmt.Sprintf(
   384  		"%s, %s, %d, %d, %d, %d, %d, %d, %d\n",
   385  		reqReceived.String(),
   386  		req.URL.String(),
   387  		blobSize,
   388  		time.Since(reqReceived).Milliseconds(),
   389  		limiterLatency,
   390  		gcsLatency,
   391  		backoffLatency,
   392  		time.Since(startResponseTime).Milliseconds(),
   393  		attempt+1,
   394  	)
   395  	f, err := os.OpenFile(h.metricsFile, os.O_APPEND|os.O_WRONLY|os.O_CREATE, 0o755)
   396  	if err != nil {
   397  		log.Printf("failed to open metrics file: %s", err)
   398  		return
   399  	}
   400  	defer f.Close()
   401  	f.WriteString(metrics)
   402  }
   403  
   404  // getLimiter returns the limiter associated with a given address and whether
   405  // the address is allowed to make requests.
   406  func (h *redirectHandler) getLimiter(addr string) (*rate.Limiter, bool) {
   407  	var limiter any
   408  	var ok bool
   409  	if limiter, ok = h.limiters.Load(addr); !ok {
   410  		// While the full address may have been in the `-allowed` list, its
   411  		// hostname may be; in that case, dynamically add a limiter for that address.
   412  		hostname := strings.Split(addr, ":")[0]
   413  		_, hostOK := h.limiters.Load(hostname)
   414  
   415  		if !h.restrictAddrs || hostOK {
   416  			limiter = newLimiter()
   417  			h.limiters.Store(addr, limiter)
   418  			ok = true
   419  		}
   420  	}
   421  	if ok {
   422  		return limiter.(*rate.Limiter), ok
   423  	}
   424  	return nil, false
   425  }
   426  
   427  // Returns an HTTP client with the credentials to read from GCS. If no
   428  // credential file is supplied, then the default application credentials
   429  // will be used.
   430  func httpClient(ctx context.Context, credFile string) (*http.Client, error) {
   431  	var creds *google.Credentials
   432  	var err error
   433  	if credFile == "" {
   434  		creds, err = google.FindDefaultCredentials(ctx, gcsReadOnlyScope)
   435  		if err != nil {
   436  			return nil, fmt.Errorf("failed to find default credentials: %w", err)
   437  		}
   438  	} else {
   439  		contents, err := os.ReadFile(credFile)
   440  		if err != nil {
   441  			return nil, err
   442  		}
   443  		creds, err = google.CredentialsFromJSON(ctx, contents, gcsReadOnlyScope)
   444  		if err != nil {
   445  			return nil, fmt.Errorf("failed to derive the derive credentials from supplied file: %w", err)
   446  		}
   447  	}
   448  	return oauth2.NewClient(ctx, creds.TokenSource), nil
   449  }