go.fuchsia.dev/infra@v0.0.0-20240507153436-9b593402251b/cmd/gcsproxy/main_test.go (about) 1 package main 2 3 import ( 4 "fmt" 5 "net/http" 6 "net/http/httptest" 7 "net/url" 8 "sync" 9 "testing" 10 "time" 11 ) 12 13 func TestServeHTTP(t *testing.T) { 14 const ( 15 headerKey = "K" // needs to be all caps 16 headerVal = "v" 17 ) 18 19 var errorsToReturn []int 20 server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 21 if len(errorsToReturn) > 0 { 22 http.Error(w, "oops", errorsToReturn[0]) 23 errorsToReturn = errorsToReturn[1:] 24 return 25 } 26 w.Header().Add(headerKey, headerVal) 27 fmt.Fprintln(w, "hello there") 28 })) 29 defer server.Close() 30 serverURL, err := url.Parse(server.URL) 31 if err != nil { 32 t.Fatalf("failed to parse test server URL: %s", err) 33 } 34 35 savedGCSHost := gcsHost 36 savedRetryBackoff := retryBackoff 37 defer func() { 38 gcsHost = savedGCSHost 39 retryBackoff = savedRetryBackoff 40 }() 41 gcsHost = serverURL.Host 42 retryBackoff = time.Nanosecond 43 44 tests := []struct { 45 name string 46 restrictAddrs bool 47 serverErrors []int 48 ignoreHeaders bool 49 wantStatus int 50 }{ 51 { 52 name: "success", 53 wantStatus: http.StatusOK, 54 }, 55 { 56 name: "restricted address returns forbidden", 57 restrictAddrs: true, 58 ignoreHeaders: true, 59 wantStatus: http.StatusForbidden, 60 }, 61 { 62 name: "retries on server error", 63 serverErrors: []int{http.StatusNotFound, http.StatusBadGateway, http.StatusTooManyRequests}, 64 wantStatus: http.StatusOK, 65 }, 66 } 67 68 for _, tt := range tests { 69 t.Run(tt.name, func(t *testing.T) { 70 rh := redirectHandler{client: server.Client(), restrictAddrs: tt.restrictAddrs, limiters: new(sync.Map)} 71 respWriter := httptest.NewRecorder() 72 errorsToReturn = tt.serverErrors 73 rh.ServeHTTP(respWriter, httptest.NewRequest("", "http://gcsproxy", nil)) 74 res := respWriter.Result() 75 if res.StatusCode != tt.wantStatus { 76 t.Errorf("rh.Serve() returned status %s, want %d", res.Status, tt.wantStatus) 77 } 78 if len(errorsToReturn) != 0 { 79 t.Errorf("rh.Serve() did not consume all errors. Remaining: %d", errorsToReturn) 80 } 81 if !tt.ignoreHeaders { 82 gotVals, ok := res.Header[headerKey] 83 if !ok || len(gotVals) != 1 || gotVals[0] != headerVal { 84 t.Errorf("res.Header is %s, missing {%s: %s}", res.Header, headerKey, headerVal) 85 } 86 } 87 }) 88 } 89 }