go.fuchsia.dev/infra@v0.0.0-20240507153436-9b593402251b/cmd/gcsproxy/main_test.go (about)

     1  package main
     2  
     3  import (
     4  	"fmt"
     5  	"net/http"
     6  	"net/http/httptest"
     7  	"net/url"
     8  	"sync"
     9  	"testing"
    10  	"time"
    11  )
    12  
    13  func TestServeHTTP(t *testing.T) {
    14  	const (
    15  		headerKey = "K" // needs to be all caps
    16  		headerVal = "v"
    17  	)
    18  
    19  	var errorsToReturn []int
    20  	server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
    21  		if len(errorsToReturn) > 0 {
    22  			http.Error(w, "oops", errorsToReturn[0])
    23  			errorsToReturn = errorsToReturn[1:]
    24  			return
    25  		}
    26  		w.Header().Add(headerKey, headerVal)
    27  		fmt.Fprintln(w, "hello there")
    28  	}))
    29  	defer server.Close()
    30  	serverURL, err := url.Parse(server.URL)
    31  	if err != nil {
    32  		t.Fatalf("failed to parse test server URL: %s", err)
    33  	}
    34  
    35  	savedGCSHost := gcsHost
    36  	savedRetryBackoff := retryBackoff
    37  	defer func() {
    38  		gcsHost = savedGCSHost
    39  		retryBackoff = savedRetryBackoff
    40  	}()
    41  	gcsHost = serverURL.Host
    42  	retryBackoff = time.Nanosecond
    43  
    44  	tests := []struct {
    45  		name          string
    46  		restrictAddrs bool
    47  		serverErrors  []int
    48  		ignoreHeaders bool
    49  		wantStatus    int
    50  	}{
    51  		{
    52  			name:       "success",
    53  			wantStatus: http.StatusOK,
    54  		},
    55  		{
    56  			name:          "restricted address returns forbidden",
    57  			restrictAddrs: true,
    58  			ignoreHeaders: true,
    59  			wantStatus:    http.StatusForbidden,
    60  		},
    61  		{
    62  			name:         "retries on server error",
    63  			serverErrors: []int{http.StatusNotFound, http.StatusBadGateway, http.StatusTooManyRequests},
    64  			wantStatus:   http.StatusOK,
    65  		},
    66  	}
    67  
    68  	for _, tt := range tests {
    69  		t.Run(tt.name, func(t *testing.T) {
    70  			rh := redirectHandler{client: server.Client(), restrictAddrs: tt.restrictAddrs, limiters: new(sync.Map)}
    71  			respWriter := httptest.NewRecorder()
    72  			errorsToReturn = tt.serverErrors
    73  			rh.ServeHTTP(respWriter, httptest.NewRequest("", "http://gcsproxy", nil))
    74  			res := respWriter.Result()
    75  			if res.StatusCode != tt.wantStatus {
    76  				t.Errorf("rh.Serve() returned status %s, want %d", res.Status, tt.wantStatus)
    77  			}
    78  			if len(errorsToReturn) != 0 {
    79  				t.Errorf("rh.Serve() did not consume all errors. Remaining: %d", errorsToReturn)
    80  			}
    81  			if !tt.ignoreHeaders {
    82  				gotVals, ok := res.Header[headerKey]
    83  				if !ok || len(gotVals) != 1 || gotVals[0] != headerVal {
    84  					t.Errorf("res.Header is %s, missing {%s: %s}", res.Header, headerKey, headerVal)
    85  				}
    86  			}
    87  		})
    88  	}
    89  }