github.com/tcnksm/go@v0.0.0-20141208075154-439b32936367/src/net/http/httputil/reverseproxy_test.go (about) 1 // Copyright 2011 The Go Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 // Reverse proxy tests. 6 7 package httputil 8 9 import ( 10 "io/ioutil" 11 "net/http" 12 "net/http/httptest" 13 "net/url" 14 "strings" 15 "testing" 16 "time" 17 ) 18 19 const fakeHopHeader = "X-Fake-Hop-Header-For-Test" 20 21 func init() { 22 hopHeaders = append(hopHeaders, fakeHopHeader) 23 } 24 25 func TestReverseProxy(t *testing.T) { 26 const backendResponse = "I am the backend" 27 const backendStatus = 404 28 backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 29 if len(r.TransferEncoding) > 0 { 30 t.Errorf("backend got unexpected TransferEncoding: %v", r.TransferEncoding) 31 } 32 if r.Header.Get("X-Forwarded-For") == "" { 33 t.Errorf("didn't get X-Forwarded-For header") 34 } 35 if c := r.Header.Get("Connection"); c != "" { 36 t.Errorf("handler got Connection header value %q", c) 37 } 38 if c := r.Header.Get("Upgrade"); c != "" { 39 t.Errorf("handler got Upgrade header value %q", c) 40 } 41 if g, e := r.Host, "some-name"; g != e { 42 t.Errorf("backend got Host header %q, want %q", g, e) 43 } 44 w.Header().Set("X-Foo", "bar") 45 w.Header().Set("Upgrade", "foo") 46 w.Header().Set(fakeHopHeader, "foo") 47 w.Header().Add("X-Multi-Value", "foo") 48 w.Header().Add("X-Multi-Value", "bar") 49 http.SetCookie(w, &http.Cookie{Name: "flavor", Value: "chocolateChip"}) 50 w.WriteHeader(backendStatus) 51 w.Write([]byte(backendResponse)) 52 })) 53 defer backend.Close() 54 backendURL, err := url.Parse(backend.URL) 55 if err != nil { 56 t.Fatal(err) 57 } 58 proxyHandler := NewSingleHostReverseProxy(backendURL) 59 frontend := httptest.NewServer(proxyHandler) 60 defer frontend.Close() 61 62 getReq, _ := http.NewRequest("GET", frontend.URL, nil) 63 getReq.Host = "some-name" 64 getReq.Header.Set("Connection", "close") 65 getReq.Header.Set("Upgrade", "foo") 66 getReq.Close = true 67 res, err := http.DefaultClient.Do(getReq) 68 if err != nil { 69 t.Fatalf("Get: %v", err) 70 } 71 if g, e := res.StatusCode, backendStatus; g != e { 72 t.Errorf("got res.StatusCode %d; expected %d", g, e) 73 } 74 if g, e := res.Header.Get("X-Foo"), "bar"; g != e { 75 t.Errorf("got X-Foo %q; expected %q", g, e) 76 } 77 if c := res.Header.Get(fakeHopHeader); c != "" { 78 t.Errorf("got %s header value %q", fakeHopHeader, c) 79 } 80 if g, e := len(res.Header["X-Multi-Value"]), 2; g != e { 81 t.Errorf("got %d X-Multi-Value header values; expected %d", g, e) 82 } 83 if g, e := len(res.Header["Set-Cookie"]), 1; g != e { 84 t.Fatalf("got %d SetCookies, want %d", g, e) 85 } 86 if cookie := res.Cookies()[0]; cookie.Name != "flavor" { 87 t.Errorf("unexpected cookie %q", cookie.Name) 88 } 89 bodyBytes, _ := ioutil.ReadAll(res.Body) 90 if g, e := string(bodyBytes), backendResponse; g != e { 91 t.Errorf("got body %q; expected %q", g, e) 92 } 93 } 94 95 func TestXForwardedFor(t *testing.T) { 96 const prevForwardedFor = "client ip" 97 const backendResponse = "I am the backend" 98 const backendStatus = 404 99 backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 100 if r.Header.Get("X-Forwarded-For") == "" { 101 t.Errorf("didn't get X-Forwarded-For header") 102 } 103 if !strings.Contains(r.Header.Get("X-Forwarded-For"), prevForwardedFor) { 104 t.Errorf("X-Forwarded-For didn't contain prior data") 105 } 106 w.WriteHeader(backendStatus) 107 w.Write([]byte(backendResponse)) 108 })) 109 defer backend.Close() 110 backendURL, err := url.Parse(backend.URL) 111 if err != nil { 112 t.Fatal(err) 113 } 114 proxyHandler := NewSingleHostReverseProxy(backendURL) 115 frontend := httptest.NewServer(proxyHandler) 116 defer frontend.Close() 117 118 getReq, _ := http.NewRequest("GET", frontend.URL, nil) 119 getReq.Host = "some-name" 120 getReq.Header.Set("Connection", "close") 121 getReq.Header.Set("X-Forwarded-For", prevForwardedFor) 122 getReq.Close = true 123 res, err := http.DefaultClient.Do(getReq) 124 if err != nil { 125 t.Fatalf("Get: %v", err) 126 } 127 if g, e := res.StatusCode, backendStatus; g != e { 128 t.Errorf("got res.StatusCode %d; expected %d", g, e) 129 } 130 bodyBytes, _ := ioutil.ReadAll(res.Body) 131 if g, e := string(bodyBytes), backendResponse; g != e { 132 t.Errorf("got body %q; expected %q", g, e) 133 } 134 } 135 136 var proxyQueryTests = []struct { 137 baseSuffix string // suffix to add to backend URL 138 reqSuffix string // suffix to add to frontend's request URL 139 want string // what backend should see for final request URL (without ?) 140 }{ 141 {"", "", ""}, 142 {"?sta=tic", "?us=er", "sta=tic&us=er"}, 143 {"", "?us=er", "us=er"}, 144 {"?sta=tic", "", "sta=tic"}, 145 } 146 147 func TestReverseProxyQuery(t *testing.T) { 148 backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 149 w.Header().Set("X-Got-Query", r.URL.RawQuery) 150 w.Write([]byte("hi")) 151 })) 152 defer backend.Close() 153 154 for i, tt := range proxyQueryTests { 155 backendURL, err := url.Parse(backend.URL + tt.baseSuffix) 156 if err != nil { 157 t.Fatal(err) 158 } 159 frontend := httptest.NewServer(NewSingleHostReverseProxy(backendURL)) 160 req, _ := http.NewRequest("GET", frontend.URL+tt.reqSuffix, nil) 161 req.Close = true 162 res, err := http.DefaultClient.Do(req) 163 if err != nil { 164 t.Fatalf("%d. Get: %v", i, err) 165 } 166 if g, e := res.Header.Get("X-Got-Query"), tt.want; g != e { 167 t.Errorf("%d. got query %q; expected %q", i, g, e) 168 } 169 res.Body.Close() 170 frontend.Close() 171 } 172 } 173 174 func TestReverseProxyFlushInterval(t *testing.T) { 175 const expected = "hi" 176 backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 177 w.Write([]byte(expected)) 178 })) 179 defer backend.Close() 180 181 backendURL, err := url.Parse(backend.URL) 182 if err != nil { 183 t.Fatal(err) 184 } 185 186 proxyHandler := NewSingleHostReverseProxy(backendURL) 187 proxyHandler.FlushInterval = time.Microsecond 188 189 done := make(chan bool) 190 onExitFlushLoop = func() { done <- true } 191 defer func() { onExitFlushLoop = nil }() 192 193 frontend := httptest.NewServer(proxyHandler) 194 defer frontend.Close() 195 196 req, _ := http.NewRequest("GET", frontend.URL, nil) 197 req.Close = true 198 res, err := http.DefaultClient.Do(req) 199 if err != nil { 200 t.Fatalf("Get: %v", err) 201 } 202 defer res.Body.Close() 203 if bodyBytes, _ := ioutil.ReadAll(res.Body); string(bodyBytes) != expected { 204 t.Errorf("got body %q; expected %q", bodyBytes, expected) 205 } 206 207 select { 208 case <-done: 209 // OK 210 case <-time.After(5 * time.Second): 211 t.Error("maxLatencyWriter flushLoop() never exited") 212 } 213 }