github.com/m10x/go/src@v0.0.0-20220112094212-ba61592315da/net/http/httputil/reverseproxy_test.go (about) 1 // Copyright 2011 The Go Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 // Reverse proxy tests. 6 7 package httputil 8 9 import ( 10 "bufio" 11 "bytes" 12 "context" 13 "errors" 14 "fmt" 15 "io" 16 "log" 17 "net/http" 18 "net/http/httptest" 19 "net/http/internal/ascii" 20 "net/url" 21 "os" 22 "reflect" 23 "sort" 24 "strconv" 25 "strings" 26 "sync" 27 "testing" 28 "time" 29 ) 30 31 const fakeHopHeader = "X-Fake-Hop-Header-For-Test" 32 33 func init() { 34 inOurTests = true 35 hopHeaders = append(hopHeaders, fakeHopHeader) 36 } 37 38 func TestReverseProxy(t *testing.T) { 39 const backendResponse = "I am the backend" 40 const backendStatus = 404 41 backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 42 if r.Method == "GET" && r.FormValue("mode") == "hangup" { 43 c, _, _ := w.(http.Hijacker).Hijack() 44 c.Close() 45 return 46 } 47 if len(r.TransferEncoding) > 0 { 48 t.Errorf("backend got unexpected TransferEncoding: %v", r.TransferEncoding) 49 } 50 if r.Header.Get("X-Forwarded-For") == "" { 51 t.Errorf("didn't get X-Forwarded-For header") 52 } 53 if c := r.Header.Get("Connection"); c != "" { 54 t.Errorf("handler got Connection header value %q", c) 55 } 56 if c := r.Header.Get("Te"); c != "trailers" { 57 t.Errorf("handler got Te header value %q; want 'trailers'", c) 58 } 59 if c := r.Header.Get("Upgrade"); c != "" { 60 t.Errorf("handler got Upgrade header value %q", c) 61 } 62 if c := r.Header.Get("Proxy-Connection"); c != "" { 63 t.Errorf("handler got Proxy-Connection header value %q", c) 64 } 65 if g, e := r.Host, "some-name"; g != e { 66 t.Errorf("backend got Host header %q, want %q", g, e) 67 } 68 w.Header().Set("Trailers", "not a special header field name") 69 w.Header().Set("Trailer", "X-Trailer") 70 w.Header().Set("X-Foo", "bar") 71 w.Header().Set("Upgrade", "foo") 72 w.Header().Set(fakeHopHeader, "foo") 73 w.Header().Add("X-Multi-Value", "foo") 74 w.Header().Add("X-Multi-Value", "bar") 75 http.SetCookie(w, &http.Cookie{Name: "flavor", Value: "chocolateChip"}) 76 w.WriteHeader(backendStatus) 77 w.Write([]byte(backendResponse)) 78 w.Header().Set("X-Trailer", "trailer_value") 79 w.Header().Set(http.TrailerPrefix+"X-Unannounced-Trailer", "unannounced_trailer_value") 80 })) 81 defer backend.Close() 82 backendURL, err := url.Parse(backend.URL) 83 if err != nil { 84 t.Fatal(err) 85 } 86 proxyHandler := NewSingleHostReverseProxy(backendURL) 87 proxyHandler.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests 88 frontend := httptest.NewServer(proxyHandler) 89 defer frontend.Close() 90 frontendClient := frontend.Client() 91 92 getReq, _ := http.NewRequest("GET", frontend.URL, nil) 93 getReq.Host = "some-name" 94 getReq.Header.Set("Connection", "close, TE") 95 getReq.Header.Add("Te", "foo") 96 getReq.Header.Add("Te", "bar, trailers") 97 getReq.Header.Set("Proxy-Connection", "should be deleted") 98 getReq.Header.Set("Upgrade", "foo") 99 getReq.Close = true 100 res, err := frontendClient.Do(getReq) 101 if err != nil { 102 t.Fatalf("Get: %v", err) 103 } 104 if g, e := res.StatusCode, backendStatus; g != e { 105 t.Errorf("got res.StatusCode %d; expected %d", g, e) 106 } 107 if g, e := res.Header.Get("X-Foo"), "bar"; g != e { 108 t.Errorf("got X-Foo %q; expected %q", g, e) 109 } 110 if c := res.Header.Get(fakeHopHeader); c != "" { 111 t.Errorf("got %s header value %q", fakeHopHeader, c) 112 } 113 if g, e := res.Header.Get("Trailers"), "not a special header field name"; g != e { 114 t.Errorf("header Trailers = %q; want %q", g, e) 115 } 116 if g, e := len(res.Header["X-Multi-Value"]), 2; g != e { 117 t.Errorf("got %d X-Multi-Value header values; expected %d", g, e) 118 } 119 if g, e := len(res.Header["Set-Cookie"]), 1; g != e { 120 t.Fatalf("got %d SetCookies, want %d", g, e) 121 } 122 if g, e := res.Trailer, (http.Header{"X-Trailer": nil}); !reflect.DeepEqual(g, e) { 123 t.Errorf("before reading body, Trailer = %#v; want %#v", g, e) 124 } 125 if cookie := res.Cookies()[0]; cookie.Name != "flavor" { 126 t.Errorf("unexpected cookie %q", cookie.Name) 127 } 128 bodyBytes, _ := io.ReadAll(res.Body) 129 if g, e := string(bodyBytes), backendResponse; g != e { 130 t.Errorf("got body %q; expected %q", g, e) 131 } 132 if g, e := res.Trailer.Get("X-Trailer"), "trailer_value"; g != e { 133 t.Errorf("Trailer(X-Trailer) = %q ; want %q", g, e) 134 } 135 if g, e := res.Trailer.Get("X-Unannounced-Trailer"), "unannounced_trailer_value"; g != e { 136 t.Errorf("Trailer(X-Unannounced-Trailer) = %q ; want %q", g, e) 137 } 138 139 // Test that a backend failing to be reached or one which doesn't return 140 // a response results in a StatusBadGateway. 141 getReq, _ = http.NewRequest("GET", frontend.URL+"/?mode=hangup", nil) 142 getReq.Close = true 143 res, err = frontendClient.Do(getReq) 144 if err != nil { 145 t.Fatal(err) 146 } 147 res.Body.Close() 148 if res.StatusCode != http.StatusBadGateway { 149 t.Errorf("request to bad proxy = %v; want 502 StatusBadGateway", res.Status) 150 } 151 152 } 153 154 // Issue 16875: remove any proxied headers mentioned in the "Connection" 155 // header value. 156 func TestReverseProxyStripHeadersPresentInConnection(t *testing.T) { 157 const fakeConnectionToken = "X-Fake-Connection-Token" 158 const backendResponse = "I am the backend" 159 160 // someConnHeader is some arbitrary header to be declared as a hop-by-hop header 161 // in the Request's Connection header. 162 const someConnHeader = "X-Some-Conn-Header" 163 164 backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 165 if c := r.Header.Get("Connection"); c != "" { 166 t.Errorf("handler got header %q = %q; want empty", "Connection", c) 167 } 168 if c := r.Header.Get(fakeConnectionToken); c != "" { 169 t.Errorf("handler got header %q = %q; want empty", fakeConnectionToken, c) 170 } 171 if c := r.Header.Get(someConnHeader); c != "" { 172 t.Errorf("handler got header %q = %q; want empty", someConnHeader, c) 173 } 174 w.Header().Add("Connection", "Upgrade, "+fakeConnectionToken) 175 w.Header().Add("Connection", someConnHeader) 176 w.Header().Set(someConnHeader, "should be deleted") 177 w.Header().Set(fakeConnectionToken, "should be deleted") 178 io.WriteString(w, backendResponse) 179 })) 180 defer backend.Close() 181 backendURL, err := url.Parse(backend.URL) 182 if err != nil { 183 t.Fatal(err) 184 } 185 proxyHandler := NewSingleHostReverseProxy(backendURL) 186 frontend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 187 proxyHandler.ServeHTTP(w, r) 188 if c := r.Header.Get(someConnHeader); c != "should be deleted" { 189 t.Errorf("handler modified header %q = %q; want %q", someConnHeader, c, "should be deleted") 190 } 191 if c := r.Header.Get(fakeConnectionToken); c != "should be deleted" { 192 t.Errorf("handler modified header %q = %q; want %q", fakeConnectionToken, c, "should be deleted") 193 } 194 c := r.Header["Connection"] 195 var cf []string 196 for _, f := range c { 197 for _, sf := range strings.Split(f, ",") { 198 if sf = strings.TrimSpace(sf); sf != "" { 199 cf = append(cf, sf) 200 } 201 } 202 } 203 sort.Strings(cf) 204 expectedValues := []string{"Upgrade", someConnHeader, fakeConnectionToken} 205 sort.Strings(expectedValues) 206 if !reflect.DeepEqual(cf, expectedValues) { 207 t.Errorf("handler modified header %q = %q; want %q", "Connection", cf, expectedValues) 208 } 209 })) 210 defer frontend.Close() 211 212 getReq, _ := http.NewRequest("GET", frontend.URL, nil) 213 getReq.Header.Add("Connection", "Upgrade, "+fakeConnectionToken) 214 getReq.Header.Add("Connection", someConnHeader) 215 getReq.Header.Set(someConnHeader, "should be deleted") 216 getReq.Header.Set(fakeConnectionToken, "should be deleted") 217 res, err := frontend.Client().Do(getReq) 218 if err != nil { 219 t.Fatalf("Get: %v", err) 220 } 221 defer res.Body.Close() 222 bodyBytes, err := io.ReadAll(res.Body) 223 if err != nil { 224 t.Fatalf("reading body: %v", err) 225 } 226 if got, want := string(bodyBytes), backendResponse; got != want { 227 t.Errorf("got body %q; want %q", got, want) 228 } 229 if c := res.Header.Get("Connection"); c != "" { 230 t.Errorf("handler got header %q = %q; want empty", "Connection", c) 231 } 232 if c := res.Header.Get(someConnHeader); c != "" { 233 t.Errorf("handler got header %q = %q; want empty", someConnHeader, c) 234 } 235 if c := res.Header.Get(fakeConnectionToken); c != "" { 236 t.Errorf("handler got header %q = %q; want empty", fakeConnectionToken, c) 237 } 238 } 239 240 func TestReverseProxyStripEmptyConnection(t *testing.T) { 241 // See Issue 46313. 242 const backendResponse = "I am the backend" 243 244 // someConnHeader is some arbitrary header to be declared as a hop-by-hop header 245 // in the Request's Connection header. 246 const someConnHeader = "X-Some-Conn-Header" 247 248 backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 249 if c := r.Header.Values("Connection"); len(c) != 0 { 250 t.Errorf("handler got header %q = %v; want empty", "Connection", c) 251 } 252 if c := r.Header.Get(someConnHeader); c != "" { 253 t.Errorf("handler got header %q = %q; want empty", someConnHeader, c) 254 } 255 w.Header().Add("Connection", "") 256 w.Header().Add("Connection", someConnHeader) 257 w.Header().Set(someConnHeader, "should be deleted") 258 io.WriteString(w, backendResponse) 259 })) 260 defer backend.Close() 261 backendURL, err := url.Parse(backend.URL) 262 if err != nil { 263 t.Fatal(err) 264 } 265 proxyHandler := NewSingleHostReverseProxy(backendURL) 266 frontend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 267 proxyHandler.ServeHTTP(w, r) 268 if c := r.Header.Get(someConnHeader); c != "should be deleted" { 269 t.Errorf("handler modified header %q = %q; want %q", someConnHeader, c, "should be deleted") 270 } 271 })) 272 defer frontend.Close() 273 274 getReq, _ := http.NewRequest("GET", frontend.URL, nil) 275 getReq.Header.Add("Connection", "") 276 getReq.Header.Add("Connection", someConnHeader) 277 getReq.Header.Set(someConnHeader, "should be deleted") 278 res, err := frontend.Client().Do(getReq) 279 if err != nil { 280 t.Fatalf("Get: %v", err) 281 } 282 defer res.Body.Close() 283 bodyBytes, err := io.ReadAll(res.Body) 284 if err != nil { 285 t.Fatalf("reading body: %v", err) 286 } 287 if got, want := string(bodyBytes), backendResponse; got != want { 288 t.Errorf("got body %q; want %q", got, want) 289 } 290 if c := res.Header.Get("Connection"); c != "" { 291 t.Errorf("handler got header %q = %q; want empty", "Connection", c) 292 } 293 if c := res.Header.Get(someConnHeader); c != "" { 294 t.Errorf("handler got header %q = %q; want empty", someConnHeader, c) 295 } 296 } 297 298 func TestXForwardedFor(t *testing.T) { 299 const prevForwardedFor = "client ip" 300 const backendResponse = "I am the backend" 301 const backendStatus = 404 302 backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 303 if r.Header.Get("X-Forwarded-For") == "" { 304 t.Errorf("didn't get X-Forwarded-For header") 305 } 306 if !strings.Contains(r.Header.Get("X-Forwarded-For"), prevForwardedFor) { 307 t.Errorf("X-Forwarded-For didn't contain prior data") 308 } 309 w.WriteHeader(backendStatus) 310 w.Write([]byte(backendResponse)) 311 })) 312 defer backend.Close() 313 backendURL, err := url.Parse(backend.URL) 314 if err != nil { 315 t.Fatal(err) 316 } 317 proxyHandler := NewSingleHostReverseProxy(backendURL) 318 frontend := httptest.NewServer(proxyHandler) 319 defer frontend.Close() 320 321 getReq, _ := http.NewRequest("GET", frontend.URL, nil) 322 getReq.Host = "some-name" 323 getReq.Header.Set("Connection", "close") 324 getReq.Header.Set("X-Forwarded-For", prevForwardedFor) 325 getReq.Close = true 326 res, err := frontend.Client().Do(getReq) 327 if err != nil { 328 t.Fatalf("Get: %v", err) 329 } 330 if g, e := res.StatusCode, backendStatus; g != e { 331 t.Errorf("got res.StatusCode %d; expected %d", g, e) 332 } 333 bodyBytes, _ := io.ReadAll(res.Body) 334 if g, e := string(bodyBytes), backendResponse; g != e { 335 t.Errorf("got body %q; expected %q", g, e) 336 } 337 } 338 339 // Issue 38079: don't append to X-Forwarded-For if it's present but nil 340 func TestXForwardedFor_Omit(t *testing.T) { 341 backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 342 if v := r.Header.Get("X-Forwarded-For"); v != "" { 343 t.Errorf("got X-Forwarded-For header: %q", v) 344 } 345 w.Write([]byte("hi")) 346 })) 347 defer backend.Close() 348 backendURL, err := url.Parse(backend.URL) 349 if err != nil { 350 t.Fatal(err) 351 } 352 proxyHandler := NewSingleHostReverseProxy(backendURL) 353 frontend := httptest.NewServer(proxyHandler) 354 defer frontend.Close() 355 356 oldDirector := proxyHandler.Director 357 proxyHandler.Director = func(r *http.Request) { 358 r.Header["X-Forwarded-For"] = nil 359 oldDirector(r) 360 } 361 362 getReq, _ := http.NewRequest("GET", frontend.URL, nil) 363 getReq.Host = "some-name" 364 getReq.Close = true 365 res, err := frontend.Client().Do(getReq) 366 if err != nil { 367 t.Fatalf("Get: %v", err) 368 } 369 res.Body.Close() 370 } 371 372 var proxyQueryTests = []struct { 373 baseSuffix string // suffix to add to backend URL 374 reqSuffix string // suffix to add to frontend's request URL 375 want string // what backend should see for final request URL (without ?) 376 }{ 377 {"", "", ""}, 378 {"?sta=tic", "?us=er", "sta=tic&us=er"}, 379 {"", "?us=er", "us=er"}, 380 {"?sta=tic", "", "sta=tic"}, 381 } 382 383 func TestReverseProxyQuery(t *testing.T) { 384 backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 385 w.Header().Set("X-Got-Query", r.URL.RawQuery) 386 w.Write([]byte("hi")) 387 })) 388 defer backend.Close() 389 390 for i, tt := range proxyQueryTests { 391 backendURL, err := url.Parse(backend.URL + tt.baseSuffix) 392 if err != nil { 393 t.Fatal(err) 394 } 395 frontend := httptest.NewServer(NewSingleHostReverseProxy(backendURL)) 396 req, _ := http.NewRequest("GET", frontend.URL+tt.reqSuffix, nil) 397 req.Close = true 398 res, err := frontend.Client().Do(req) 399 if err != nil { 400 t.Fatalf("%d. Get: %v", i, err) 401 } 402 if g, e := res.Header.Get("X-Got-Query"), tt.want; g != e { 403 t.Errorf("%d. got query %q; expected %q", i, g, e) 404 } 405 res.Body.Close() 406 frontend.Close() 407 } 408 } 409 410 func TestReverseProxyFlushInterval(t *testing.T) { 411 const expected = "hi" 412 backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 413 w.Write([]byte(expected)) 414 })) 415 defer backend.Close() 416 417 backendURL, err := url.Parse(backend.URL) 418 if err != nil { 419 t.Fatal(err) 420 } 421 422 proxyHandler := NewSingleHostReverseProxy(backendURL) 423 proxyHandler.FlushInterval = time.Microsecond 424 425 frontend := httptest.NewServer(proxyHandler) 426 defer frontend.Close() 427 428 req, _ := http.NewRequest("GET", frontend.URL, nil) 429 req.Close = true 430 res, err := frontend.Client().Do(req) 431 if err != nil { 432 t.Fatalf("Get: %v", err) 433 } 434 defer res.Body.Close() 435 if bodyBytes, _ := io.ReadAll(res.Body); string(bodyBytes) != expected { 436 t.Errorf("got body %q; expected %q", bodyBytes, expected) 437 } 438 } 439 440 func TestReverseProxyFlushIntervalHeaders(t *testing.T) { 441 const expected = "hi" 442 stopCh := make(chan struct{}) 443 backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 444 w.Header().Add("MyHeader", expected) 445 w.WriteHeader(200) 446 w.(http.Flusher).Flush() 447 <-stopCh 448 })) 449 defer backend.Close() 450 defer close(stopCh) 451 452 backendURL, err := url.Parse(backend.URL) 453 if err != nil { 454 t.Fatal(err) 455 } 456 457 proxyHandler := NewSingleHostReverseProxy(backendURL) 458 proxyHandler.FlushInterval = time.Microsecond 459 460 frontend := httptest.NewServer(proxyHandler) 461 defer frontend.Close() 462 463 req, _ := http.NewRequest("GET", frontend.URL, nil) 464 req.Close = true 465 466 ctx, cancel := context.WithTimeout(req.Context(), 10*time.Second) 467 defer cancel() 468 req = req.WithContext(ctx) 469 470 res, err := frontend.Client().Do(req) 471 if err != nil { 472 t.Fatalf("Get: %v", err) 473 } 474 defer res.Body.Close() 475 476 if res.Header.Get("MyHeader") != expected { 477 t.Errorf("got header %q; expected %q", res.Header.Get("MyHeader"), expected) 478 } 479 } 480 481 func TestReverseProxyCancellation(t *testing.T) { 482 const backendResponse = "I am the backend" 483 484 reqInFlight := make(chan struct{}) 485 backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 486 close(reqInFlight) // cause the client to cancel its request 487 488 select { 489 case <-time.After(10 * time.Second): 490 // Note: this should only happen in broken implementations, and the 491 // closenotify case should be instantaneous. 492 t.Error("Handler never saw CloseNotify") 493 return 494 case <-w.(http.CloseNotifier).CloseNotify(): 495 } 496 497 w.WriteHeader(http.StatusOK) 498 w.Write([]byte(backendResponse)) 499 })) 500 501 defer backend.Close() 502 503 backend.Config.ErrorLog = log.New(io.Discard, "", 0) 504 505 backendURL, err := url.Parse(backend.URL) 506 if err != nil { 507 t.Fatal(err) 508 } 509 510 proxyHandler := NewSingleHostReverseProxy(backendURL) 511 512 // Discards errors of the form: 513 // http: proxy error: read tcp 127.0.0.1:44643: use of closed network connection 514 proxyHandler.ErrorLog = log.New(io.Discard, "", 0) 515 516 frontend := httptest.NewServer(proxyHandler) 517 defer frontend.Close() 518 frontendClient := frontend.Client() 519 520 getReq, _ := http.NewRequest("GET", frontend.URL, nil) 521 go func() { 522 <-reqInFlight 523 frontendClient.Transport.(*http.Transport).CancelRequest(getReq) 524 }() 525 res, err := frontendClient.Do(getReq) 526 if res != nil { 527 t.Errorf("got response %v; want nil", res.Status) 528 } 529 if err == nil { 530 // This should be an error like: 531 // Get "http://127.0.0.1:58079": read tcp 127.0.0.1:58079: 532 // use of closed network connection 533 t.Error("Server.Client().Do() returned nil error; want non-nil error") 534 } 535 } 536 537 func req(t *testing.T, v string) *http.Request { 538 req, err := http.ReadRequest(bufio.NewReader(strings.NewReader(v))) 539 if err != nil { 540 t.Fatal(err) 541 } 542 return req 543 } 544 545 // Issue 12344 546 func TestNilBody(t *testing.T) { 547 backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 548 w.Write([]byte("hi")) 549 })) 550 defer backend.Close() 551 552 frontend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { 553 backURL, _ := url.Parse(backend.URL) 554 rp := NewSingleHostReverseProxy(backURL) 555 r := req(t, "GET / HTTP/1.0\r\n\r\n") 556 r.Body = nil // this accidentally worked in Go 1.4 and below, so keep it working 557 rp.ServeHTTP(w, r) 558 })) 559 defer frontend.Close() 560 561 res, err := http.Get(frontend.URL) 562 if err != nil { 563 t.Fatal(err) 564 } 565 defer res.Body.Close() 566 slurp, err := io.ReadAll(res.Body) 567 if err != nil { 568 t.Fatal(err) 569 } 570 if string(slurp) != "hi" { 571 t.Errorf("Got %q; want %q", slurp, "hi") 572 } 573 } 574 575 // Issue 15524 576 func TestUserAgentHeader(t *testing.T) { 577 const explicitUA = "explicit UA" 578 backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 579 if r.URL.Path == "/noua" { 580 if c := r.Header.Get("User-Agent"); c != "" { 581 t.Errorf("handler got non-empty User-Agent header %q", c) 582 } 583 return 584 } 585 if c := r.Header.Get("User-Agent"); c != explicitUA { 586 t.Errorf("handler got unexpected User-Agent header %q", c) 587 } 588 })) 589 defer backend.Close() 590 backendURL, err := url.Parse(backend.URL) 591 if err != nil { 592 t.Fatal(err) 593 } 594 proxyHandler := NewSingleHostReverseProxy(backendURL) 595 proxyHandler.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests 596 frontend := httptest.NewServer(proxyHandler) 597 defer frontend.Close() 598 frontendClient := frontend.Client() 599 600 getReq, _ := http.NewRequest("GET", frontend.URL, nil) 601 getReq.Header.Set("User-Agent", explicitUA) 602 getReq.Close = true 603 res, err := frontendClient.Do(getReq) 604 if err != nil { 605 t.Fatalf("Get: %v", err) 606 } 607 res.Body.Close() 608 609 getReq, _ = http.NewRequest("GET", frontend.URL+"/noua", nil) 610 getReq.Header.Set("User-Agent", "") 611 getReq.Close = true 612 res, err = frontendClient.Do(getReq) 613 if err != nil { 614 t.Fatalf("Get: %v", err) 615 } 616 res.Body.Close() 617 } 618 619 type bufferPool struct { 620 get func() []byte 621 put func([]byte) 622 } 623 624 func (bp bufferPool) Get() []byte { return bp.get() } 625 func (bp bufferPool) Put(v []byte) { bp.put(v) } 626 627 func TestReverseProxyGetPutBuffer(t *testing.T) { 628 const msg = "hi" 629 backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 630 io.WriteString(w, msg) 631 })) 632 defer backend.Close() 633 634 backendURL, err := url.Parse(backend.URL) 635 if err != nil { 636 t.Fatal(err) 637 } 638 639 var ( 640 mu sync.Mutex 641 log []string 642 ) 643 addLog := func(event string) { 644 mu.Lock() 645 defer mu.Unlock() 646 log = append(log, event) 647 } 648 rp := NewSingleHostReverseProxy(backendURL) 649 const size = 1234 650 rp.BufferPool = bufferPool{ 651 get: func() []byte { 652 addLog("getBuf") 653 return make([]byte, size) 654 }, 655 put: func(p []byte) { 656 addLog("putBuf-" + strconv.Itoa(len(p))) 657 }, 658 } 659 frontend := httptest.NewServer(rp) 660 defer frontend.Close() 661 662 req, _ := http.NewRequest("GET", frontend.URL, nil) 663 req.Close = true 664 res, err := frontend.Client().Do(req) 665 if err != nil { 666 t.Fatalf("Get: %v", err) 667 } 668 slurp, err := io.ReadAll(res.Body) 669 res.Body.Close() 670 if err != nil { 671 t.Fatalf("reading body: %v", err) 672 } 673 if string(slurp) != msg { 674 t.Errorf("msg = %q; want %q", slurp, msg) 675 } 676 wantLog := []string{"getBuf", "putBuf-" + strconv.Itoa(size)} 677 mu.Lock() 678 defer mu.Unlock() 679 if !reflect.DeepEqual(log, wantLog) { 680 t.Errorf("Log events = %q; want %q", log, wantLog) 681 } 682 } 683 684 func TestReverseProxy_Post(t *testing.T) { 685 const backendResponse = "I am the backend" 686 const backendStatus = 200 687 var requestBody = bytes.Repeat([]byte("a"), 1<<20) 688 backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 689 slurp, err := io.ReadAll(r.Body) 690 if err != nil { 691 t.Errorf("Backend body read = %v", err) 692 } 693 if len(slurp) != len(requestBody) { 694 t.Errorf("Backend read %d request body bytes; want %d", len(slurp), len(requestBody)) 695 } 696 if !bytes.Equal(slurp, requestBody) { 697 t.Error("Backend read wrong request body.") // 1MB; omitting details 698 } 699 w.Write([]byte(backendResponse)) 700 })) 701 defer backend.Close() 702 backendURL, err := url.Parse(backend.URL) 703 if err != nil { 704 t.Fatal(err) 705 } 706 proxyHandler := NewSingleHostReverseProxy(backendURL) 707 frontend := httptest.NewServer(proxyHandler) 708 defer frontend.Close() 709 710 postReq, _ := http.NewRequest("POST", frontend.URL, bytes.NewReader(requestBody)) 711 res, err := frontend.Client().Do(postReq) 712 if err != nil { 713 t.Fatalf("Do: %v", err) 714 } 715 if g, e := res.StatusCode, backendStatus; g != e { 716 t.Errorf("got res.StatusCode %d; expected %d", g, e) 717 } 718 bodyBytes, _ := io.ReadAll(res.Body) 719 if g, e := string(bodyBytes), backendResponse; g != e { 720 t.Errorf("got body %q; expected %q", g, e) 721 } 722 } 723 724 type RoundTripperFunc func(*http.Request) (*http.Response, error) 725 726 func (fn RoundTripperFunc) RoundTrip(req *http.Request) (*http.Response, error) { 727 return fn(req) 728 } 729 730 // Issue 16036: send a Request with a nil Body when possible 731 func TestReverseProxy_NilBody(t *testing.T) { 732 backendURL, _ := url.Parse("http://fake.tld/") 733 proxyHandler := NewSingleHostReverseProxy(backendURL) 734 proxyHandler.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests 735 proxyHandler.Transport = RoundTripperFunc(func(req *http.Request) (*http.Response, error) { 736 if req.Body != nil { 737 t.Error("Body != nil; want a nil Body") 738 } 739 return nil, errors.New("done testing the interesting part; so force a 502 Gateway error") 740 }) 741 frontend := httptest.NewServer(proxyHandler) 742 defer frontend.Close() 743 744 res, err := frontend.Client().Get(frontend.URL) 745 if err != nil { 746 t.Fatal(err) 747 } 748 defer res.Body.Close() 749 if res.StatusCode != 502 { 750 t.Errorf("status code = %v; want 502 (Gateway Error)", res.Status) 751 } 752 } 753 754 // Issue 33142: always allocate the request headers 755 func TestReverseProxy_AllocatedHeader(t *testing.T) { 756 proxyHandler := new(ReverseProxy) 757 proxyHandler.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests 758 proxyHandler.Director = func(*http.Request) {} // noop 759 proxyHandler.Transport = RoundTripperFunc(func(req *http.Request) (*http.Response, error) { 760 if req.Header == nil { 761 t.Error("Header == nil; want a non-nil Header") 762 } 763 return nil, errors.New("done testing the interesting part; so force a 502 Gateway error") 764 }) 765 766 proxyHandler.ServeHTTP(httptest.NewRecorder(), &http.Request{ 767 Method: "GET", 768 URL: &url.URL{Scheme: "http", Host: "fake.tld", Path: "/"}, 769 Proto: "HTTP/1.0", 770 ProtoMajor: 1, 771 }) 772 } 773 774 // Issue 14237. Test ModifyResponse and that an error from it 775 // causes the proxy to return StatusBadGateway, or StatusOK otherwise. 776 func TestReverseProxyModifyResponse(t *testing.T) { 777 backendServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 778 w.Header().Add("X-Hit-Mod", fmt.Sprintf("%v", r.URL.Path == "/mod")) 779 })) 780 defer backendServer.Close() 781 782 rpURL, _ := url.Parse(backendServer.URL) 783 rproxy := NewSingleHostReverseProxy(rpURL) 784 rproxy.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests 785 rproxy.ModifyResponse = func(resp *http.Response) error { 786 if resp.Header.Get("X-Hit-Mod") != "true" { 787 return fmt.Errorf("tried to by-pass proxy") 788 } 789 return nil 790 } 791 792 frontendProxy := httptest.NewServer(rproxy) 793 defer frontendProxy.Close() 794 795 tests := []struct { 796 url string 797 wantCode int 798 }{ 799 {frontendProxy.URL + "/mod", http.StatusOK}, 800 {frontendProxy.URL + "/schedule", http.StatusBadGateway}, 801 } 802 803 for i, tt := range tests { 804 resp, err := http.Get(tt.url) 805 if err != nil { 806 t.Fatalf("failed to reach proxy: %v", err) 807 } 808 if g, e := resp.StatusCode, tt.wantCode; g != e { 809 t.Errorf("#%d: got res.StatusCode %d; expected %d", i, g, e) 810 } 811 resp.Body.Close() 812 } 813 } 814 815 type failingRoundTripper struct{} 816 817 func (failingRoundTripper) RoundTrip(*http.Request) (*http.Response, error) { 818 return nil, errors.New("some error") 819 } 820 821 type staticResponseRoundTripper struct{ res *http.Response } 822 823 func (rt staticResponseRoundTripper) RoundTrip(*http.Request) (*http.Response, error) { 824 return rt.res, nil 825 } 826 827 func TestReverseProxyErrorHandler(t *testing.T) { 828 tests := []struct { 829 name string 830 wantCode int 831 errorHandler func(http.ResponseWriter, *http.Request, error) 832 transport http.RoundTripper // defaults to failingRoundTripper 833 modifyResponse func(*http.Response) error 834 }{ 835 { 836 name: "default", 837 wantCode: http.StatusBadGateway, 838 }, 839 { 840 name: "errorhandler", 841 wantCode: http.StatusTeapot, 842 errorHandler: func(rw http.ResponseWriter, req *http.Request, err error) { rw.WriteHeader(http.StatusTeapot) }, 843 }, 844 { 845 name: "modifyresponse_noerr", 846 transport: staticResponseRoundTripper{ 847 &http.Response{StatusCode: 345, Body: http.NoBody}, 848 }, 849 modifyResponse: func(res *http.Response) error { 850 res.StatusCode++ 851 return nil 852 }, 853 errorHandler: func(rw http.ResponseWriter, req *http.Request, err error) { rw.WriteHeader(http.StatusTeapot) }, 854 wantCode: 346, 855 }, 856 { 857 name: "modifyresponse_err", 858 transport: staticResponseRoundTripper{ 859 &http.Response{StatusCode: 345, Body: http.NoBody}, 860 }, 861 modifyResponse: func(res *http.Response) error { 862 res.StatusCode++ 863 return errors.New("some error to trigger errorHandler") 864 }, 865 errorHandler: func(rw http.ResponseWriter, req *http.Request, err error) { rw.WriteHeader(http.StatusTeapot) }, 866 wantCode: http.StatusTeapot, 867 }, 868 } 869 870 for _, tt := range tests { 871 t.Run(tt.name, func(t *testing.T) { 872 target := &url.URL{ 873 Scheme: "http", 874 Host: "dummy.tld", 875 Path: "/", 876 } 877 rproxy := NewSingleHostReverseProxy(target) 878 rproxy.Transport = tt.transport 879 rproxy.ModifyResponse = tt.modifyResponse 880 if rproxy.Transport == nil { 881 rproxy.Transport = failingRoundTripper{} 882 } 883 rproxy.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests 884 if tt.errorHandler != nil { 885 rproxy.ErrorHandler = tt.errorHandler 886 } 887 frontendProxy := httptest.NewServer(rproxy) 888 defer frontendProxy.Close() 889 890 resp, err := http.Get(frontendProxy.URL + "/test") 891 if err != nil { 892 t.Fatalf("failed to reach proxy: %v", err) 893 } 894 if g, e := resp.StatusCode, tt.wantCode; g != e { 895 t.Errorf("got res.StatusCode %d; expected %d", g, e) 896 } 897 resp.Body.Close() 898 }) 899 } 900 } 901 902 // Issue 16659: log errors from short read 903 func TestReverseProxy_CopyBuffer(t *testing.T) { 904 backendServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 905 out := "this call was relayed by the reverse proxy" 906 // Coerce a wrong content length to induce io.UnexpectedEOF 907 w.Header().Set("Content-Length", fmt.Sprintf("%d", len(out)*2)) 908 fmt.Fprintln(w, out) 909 })) 910 defer backendServer.Close() 911 912 rpURL, err := url.Parse(backendServer.URL) 913 if err != nil { 914 t.Fatal(err) 915 } 916 917 var proxyLog bytes.Buffer 918 rproxy := NewSingleHostReverseProxy(rpURL) 919 rproxy.ErrorLog = log.New(&proxyLog, "", log.Lshortfile) 920 donec := make(chan bool, 1) 921 frontendProxy := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 922 defer func() { donec <- true }() 923 rproxy.ServeHTTP(w, r) 924 })) 925 defer frontendProxy.Close() 926 927 if _, err = frontendProxy.Client().Get(frontendProxy.URL); err == nil { 928 t.Fatalf("want non-nil error") 929 } 930 // The race detector complains about the proxyLog usage in logf in copyBuffer 931 // and our usage below with proxyLog.Bytes() so we're explicitly using a 932 // channel to ensure that the ReverseProxy's ServeHTTP is done before we 933 // continue after Get. 934 <-donec 935 936 expected := []string{ 937 "EOF", 938 "read", 939 } 940 for _, phrase := range expected { 941 if !bytes.Contains(proxyLog.Bytes(), []byte(phrase)) { 942 t.Errorf("expected log to contain phrase %q", phrase) 943 } 944 } 945 } 946 947 type staticTransport struct { 948 res *http.Response 949 } 950 951 func (t *staticTransport) RoundTrip(r *http.Request) (*http.Response, error) { 952 return t.res, nil 953 } 954 955 func BenchmarkServeHTTP(b *testing.B) { 956 res := &http.Response{ 957 StatusCode: 200, 958 Body: io.NopCloser(strings.NewReader("")), 959 } 960 proxy := &ReverseProxy{ 961 Director: func(*http.Request) {}, 962 Transport: &staticTransport{res}, 963 } 964 965 w := httptest.NewRecorder() 966 r := httptest.NewRequest("GET", "/", nil) 967 968 b.ReportAllocs() 969 for i := 0; i < b.N; i++ { 970 proxy.ServeHTTP(w, r) 971 } 972 } 973 974 func TestServeHTTPDeepCopy(t *testing.T) { 975 backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 976 w.Write([]byte("Hello Gopher!")) 977 })) 978 defer backend.Close() 979 backendURL, err := url.Parse(backend.URL) 980 if err != nil { 981 t.Fatal(err) 982 } 983 984 type result struct { 985 before, after string 986 } 987 988 resultChan := make(chan result, 1) 989 proxyHandler := NewSingleHostReverseProxy(backendURL) 990 frontend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 991 before := r.URL.String() 992 proxyHandler.ServeHTTP(w, r) 993 after := r.URL.String() 994 resultChan <- result{before: before, after: after} 995 })) 996 defer frontend.Close() 997 998 want := result{before: "/", after: "/"} 999 1000 res, err := frontend.Client().Get(frontend.URL) 1001 if err != nil { 1002 t.Fatalf("Do: %v", err) 1003 } 1004 res.Body.Close() 1005 1006 got := <-resultChan 1007 if got != want { 1008 t.Errorf("got = %+v; want = %+v", got, want) 1009 } 1010 } 1011 1012 // Issue 18327: verify we always do a deep copy of the Request.Header map 1013 // before any mutations. 1014 func TestClonesRequestHeaders(t *testing.T) { 1015 log.SetOutput(io.Discard) 1016 defer log.SetOutput(os.Stderr) 1017 req, _ := http.NewRequest("GET", "http://foo.tld/", nil) 1018 req.RemoteAddr = "1.2.3.4:56789" 1019 rp := &ReverseProxy{ 1020 Director: func(req *http.Request) { 1021 req.Header.Set("From-Director", "1") 1022 }, 1023 Transport: roundTripperFunc(func(req *http.Request) (*http.Response, error) { 1024 if v := req.Header.Get("From-Director"); v != "1" { 1025 t.Errorf("From-Directory value = %q; want 1", v) 1026 } 1027 return nil, io.EOF 1028 }), 1029 } 1030 rp.ServeHTTP(httptest.NewRecorder(), req) 1031 1032 if req.Header.Get("From-Director") == "1" { 1033 t.Error("Director header mutation modified caller's request") 1034 } 1035 if req.Header.Get("X-Forwarded-For") != "" { 1036 t.Error("X-Forward-For header mutation modified caller's request") 1037 } 1038 1039 } 1040 1041 type roundTripperFunc func(req *http.Request) (*http.Response, error) 1042 1043 func (fn roundTripperFunc) RoundTrip(req *http.Request) (*http.Response, error) { 1044 return fn(req) 1045 } 1046 1047 func TestModifyResponseClosesBody(t *testing.T) { 1048 req, _ := http.NewRequest("GET", "http://foo.tld/", nil) 1049 req.RemoteAddr = "1.2.3.4:56789" 1050 closeCheck := new(checkCloser) 1051 logBuf := new(bytes.Buffer) 1052 outErr := errors.New("ModifyResponse error") 1053 rp := &ReverseProxy{ 1054 Director: func(req *http.Request) {}, 1055 Transport: &staticTransport{&http.Response{ 1056 StatusCode: 200, 1057 Body: closeCheck, 1058 }}, 1059 ErrorLog: log.New(logBuf, "", 0), 1060 ModifyResponse: func(*http.Response) error { 1061 return outErr 1062 }, 1063 } 1064 rec := httptest.NewRecorder() 1065 rp.ServeHTTP(rec, req) 1066 res := rec.Result() 1067 if g, e := res.StatusCode, http.StatusBadGateway; g != e { 1068 t.Errorf("got res.StatusCode %d; expected %d", g, e) 1069 } 1070 if !closeCheck.closed { 1071 t.Errorf("body should have been closed") 1072 } 1073 if g, e := logBuf.String(), outErr.Error(); !strings.Contains(g, e) { 1074 t.Errorf("ErrorLog %q does not contain %q", g, e) 1075 } 1076 } 1077 1078 type checkCloser struct { 1079 closed bool 1080 } 1081 1082 func (cc *checkCloser) Close() error { 1083 cc.closed = true 1084 return nil 1085 } 1086 1087 func (cc *checkCloser) Read(b []byte) (int, error) { 1088 return len(b), nil 1089 } 1090 1091 // Issue 23643: panic on body copy error 1092 func TestReverseProxy_PanicBodyError(t *testing.T) { 1093 log.SetOutput(io.Discard) 1094 defer log.SetOutput(os.Stderr) 1095 backendServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 1096 out := "this call was relayed by the reverse proxy" 1097 // Coerce a wrong content length to induce io.ErrUnexpectedEOF 1098 w.Header().Set("Content-Length", fmt.Sprintf("%d", len(out)*2)) 1099 fmt.Fprintln(w, out) 1100 })) 1101 defer backendServer.Close() 1102 1103 rpURL, err := url.Parse(backendServer.URL) 1104 if err != nil { 1105 t.Fatal(err) 1106 } 1107 1108 rproxy := NewSingleHostReverseProxy(rpURL) 1109 1110 // Ensure that the handler panics when the body read encounters an 1111 // io.ErrUnexpectedEOF 1112 defer func() { 1113 err := recover() 1114 if err == nil { 1115 t.Fatal("handler should have panicked") 1116 } 1117 if err != http.ErrAbortHandler { 1118 t.Fatal("expected ErrAbortHandler, got", err) 1119 } 1120 }() 1121 req, _ := http.NewRequest("GET", "http://foo.tld/", nil) 1122 rproxy.ServeHTTP(httptest.NewRecorder(), req) 1123 } 1124 1125 // Issue #46866: panic without closing incoming request body causes a panic 1126 func TestReverseProxy_PanicClosesIncomingBody(t *testing.T) { 1127 backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 1128 out := "this call was relayed by the reverse proxy" 1129 // Coerce a wrong content length to induce io.ErrUnexpectedEOF 1130 w.Header().Set("Content-Length", fmt.Sprintf("%d", len(out)*2)) 1131 fmt.Fprintln(w, out) 1132 })) 1133 defer backend.Close() 1134 backendURL, err := url.Parse(backend.URL) 1135 if err != nil { 1136 t.Fatal(err) 1137 } 1138 proxyHandler := NewSingleHostReverseProxy(backendURL) 1139 proxyHandler.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests 1140 frontend := httptest.NewServer(proxyHandler) 1141 defer frontend.Close() 1142 frontendClient := frontend.Client() 1143 1144 var wg sync.WaitGroup 1145 for i := 0; i < 2; i++ { 1146 wg.Add(1) 1147 go func() { 1148 defer wg.Done() 1149 for j := 0; j < 10; j++ { 1150 const reqLen = 6 * 1024 * 1024 1151 req, _ := http.NewRequest("POST", frontend.URL, &io.LimitedReader{R: neverEnding('x'), N: reqLen}) 1152 req.ContentLength = reqLen 1153 resp, _ := frontendClient.Transport.RoundTrip(req) 1154 if resp != nil { 1155 io.Copy(io.Discard, resp.Body) 1156 resp.Body.Close() 1157 } 1158 } 1159 }() 1160 } 1161 wg.Wait() 1162 } 1163 1164 func TestSelectFlushInterval(t *testing.T) { 1165 tests := []struct { 1166 name string 1167 p *ReverseProxy 1168 res *http.Response 1169 want time.Duration 1170 }{ 1171 { 1172 name: "default", 1173 res: &http.Response{}, 1174 p: &ReverseProxy{FlushInterval: 123}, 1175 want: 123, 1176 }, 1177 { 1178 name: "server-sent events overrides non-zero", 1179 res: &http.Response{ 1180 Header: http.Header{ 1181 "Content-Type": {"text/event-stream"}, 1182 }, 1183 }, 1184 p: &ReverseProxy{FlushInterval: 123}, 1185 want: -1, 1186 }, 1187 { 1188 name: "server-sent events overrides zero", 1189 res: &http.Response{ 1190 Header: http.Header{ 1191 "Content-Type": {"text/event-stream"}, 1192 }, 1193 }, 1194 p: &ReverseProxy{FlushInterval: 0}, 1195 want: -1, 1196 }, 1197 { 1198 name: "server-sent events with media-type parameters overrides non-zero", 1199 res: &http.Response{ 1200 Header: http.Header{ 1201 "Content-Type": {"text/event-stream;charset=utf-8"}, 1202 }, 1203 }, 1204 p: &ReverseProxy{FlushInterval: 123}, 1205 want: -1, 1206 }, 1207 { 1208 name: "server-sent events with media-type parameters overrides zero", 1209 res: &http.Response{ 1210 Header: http.Header{ 1211 "Content-Type": {"text/event-stream;charset=utf-8"}, 1212 }, 1213 }, 1214 p: &ReverseProxy{FlushInterval: 0}, 1215 want: -1, 1216 }, 1217 { 1218 name: "Content-Length: -1, overrides non-zero", 1219 res: &http.Response{ 1220 ContentLength: -1, 1221 }, 1222 p: &ReverseProxy{FlushInterval: 123}, 1223 want: -1, 1224 }, 1225 { 1226 name: "Content-Length: -1, overrides zero", 1227 res: &http.Response{ 1228 ContentLength: -1, 1229 }, 1230 p: &ReverseProxy{FlushInterval: 0}, 1231 want: -1, 1232 }, 1233 } 1234 for _, tt := range tests { 1235 t.Run(tt.name, func(t *testing.T) { 1236 got := tt.p.flushInterval(tt.res) 1237 if got != tt.want { 1238 t.Errorf("flushLatency = %v; want %v", got, tt.want) 1239 } 1240 }) 1241 } 1242 } 1243 1244 func TestReverseProxyWebSocket(t *testing.T) { 1245 backendServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 1246 if upgradeType(r.Header) != "websocket" { 1247 t.Error("unexpected backend request") 1248 http.Error(w, "unexpected request", 400) 1249 return 1250 } 1251 c, _, err := w.(http.Hijacker).Hijack() 1252 if err != nil { 1253 t.Error(err) 1254 return 1255 } 1256 defer c.Close() 1257 io.WriteString(c, "HTTP/1.1 101 Switching Protocols\r\nConnection: upgrade\r\nUpgrade: WebSocket\r\n\r\n") 1258 bs := bufio.NewScanner(c) 1259 if !bs.Scan() { 1260 t.Errorf("backend failed to read line from client: %v", bs.Err()) 1261 return 1262 } 1263 fmt.Fprintf(c, "backend got %q\n", bs.Text()) 1264 })) 1265 defer backendServer.Close() 1266 1267 backURL, _ := url.Parse(backendServer.URL) 1268 rproxy := NewSingleHostReverseProxy(backURL) 1269 rproxy.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests 1270 rproxy.ModifyResponse = func(res *http.Response) error { 1271 res.Header.Add("X-Modified", "true") 1272 return nil 1273 } 1274 1275 handler := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { 1276 rw.Header().Set("X-Header", "X-Value") 1277 rproxy.ServeHTTP(rw, req) 1278 if got, want := rw.Header().Get("X-Modified"), "true"; got != want { 1279 t.Errorf("response writer X-Modified header = %q; want %q", got, want) 1280 } 1281 }) 1282 1283 frontendProxy := httptest.NewServer(handler) 1284 defer frontendProxy.Close() 1285 1286 req, _ := http.NewRequest("GET", frontendProxy.URL, nil) 1287 req.Header.Set("Connection", "Upgrade") 1288 req.Header.Set("Upgrade", "websocket") 1289 1290 c := frontendProxy.Client() 1291 res, err := c.Do(req) 1292 if err != nil { 1293 t.Fatal(err) 1294 } 1295 if res.StatusCode != 101 { 1296 t.Fatalf("status = %v; want 101", res.Status) 1297 } 1298 1299 got := res.Header.Get("X-Header") 1300 want := "X-Value" 1301 if got != want { 1302 t.Errorf("Header(XHeader) = %q; want %q", got, want) 1303 } 1304 1305 if !ascii.EqualFold(upgradeType(res.Header), "websocket") { 1306 t.Fatalf("not websocket upgrade; got %#v", res.Header) 1307 } 1308 rwc, ok := res.Body.(io.ReadWriteCloser) 1309 if !ok { 1310 t.Fatalf("response body is of type %T; does not implement ReadWriteCloser", res.Body) 1311 } 1312 defer rwc.Close() 1313 1314 if got, want := res.Header.Get("X-Modified"), "true"; got != want { 1315 t.Errorf("response X-Modified header = %q; want %q", got, want) 1316 } 1317 1318 io.WriteString(rwc, "Hello\n") 1319 bs := bufio.NewScanner(rwc) 1320 if !bs.Scan() { 1321 t.Fatalf("Scan: %v", bs.Err()) 1322 } 1323 got = bs.Text() 1324 want = `backend got "Hello"` 1325 if got != want { 1326 t.Errorf("got %#q, want %#q", got, want) 1327 } 1328 } 1329 1330 func TestReverseProxyWebSocketCancellation(t *testing.T) { 1331 n := 5 1332 triggerCancelCh := make(chan bool, n) 1333 nthResponse := func(i int) string { 1334 return fmt.Sprintf("backend response #%d\n", i) 1335 } 1336 terminalMsg := "final message" 1337 1338 cst := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 1339 if g, ws := upgradeType(r.Header), "websocket"; g != ws { 1340 t.Errorf("Unexpected upgrade type %q, want %q", g, ws) 1341 http.Error(w, "Unexpected request", 400) 1342 return 1343 } 1344 conn, bufrw, err := w.(http.Hijacker).Hijack() 1345 if err != nil { 1346 t.Error(err) 1347 return 1348 } 1349 defer conn.Close() 1350 1351 upgradeMsg := "HTTP/1.1 101 Switching Protocols\r\nConnection: upgrade\r\nUpgrade: WebSocket\r\n\r\n" 1352 if _, err := io.WriteString(conn, upgradeMsg); err != nil { 1353 t.Error(err) 1354 return 1355 } 1356 if _, _, err := bufrw.ReadLine(); err != nil { 1357 t.Errorf("Failed to read line from client: %v", err) 1358 return 1359 } 1360 1361 for i := 0; i < n; i++ { 1362 if _, err := bufrw.WriteString(nthResponse(i)); err != nil { 1363 select { 1364 case <-triggerCancelCh: 1365 default: 1366 t.Errorf("Writing response #%d failed: %v", i, err) 1367 } 1368 return 1369 } 1370 bufrw.Flush() 1371 time.Sleep(time.Second) 1372 } 1373 if _, err := bufrw.WriteString(terminalMsg); err != nil { 1374 select { 1375 case <-triggerCancelCh: 1376 default: 1377 t.Errorf("Failed to write terminal message: %v", err) 1378 } 1379 } 1380 bufrw.Flush() 1381 })) 1382 defer cst.Close() 1383 1384 backendURL, _ := url.Parse(cst.URL) 1385 rproxy := NewSingleHostReverseProxy(backendURL) 1386 rproxy.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests 1387 rproxy.ModifyResponse = func(res *http.Response) error { 1388 res.Header.Add("X-Modified", "true") 1389 return nil 1390 } 1391 1392 handler := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { 1393 rw.Header().Set("X-Header", "X-Value") 1394 ctx, cancel := context.WithCancel(req.Context()) 1395 go func() { 1396 <-triggerCancelCh 1397 cancel() 1398 }() 1399 rproxy.ServeHTTP(rw, req.WithContext(ctx)) 1400 }) 1401 1402 frontendProxy := httptest.NewServer(handler) 1403 defer frontendProxy.Close() 1404 1405 req, _ := http.NewRequest("GET", frontendProxy.URL, nil) 1406 req.Header.Set("Connection", "Upgrade") 1407 req.Header.Set("Upgrade", "websocket") 1408 1409 res, err := frontendProxy.Client().Do(req) 1410 if err != nil { 1411 t.Fatalf("Dialing to frontend proxy: %v", err) 1412 } 1413 defer res.Body.Close() 1414 if g, w := res.StatusCode, 101; g != w { 1415 t.Fatalf("Switching protocols failed, got: %d, want: %d", g, w) 1416 } 1417 1418 if g, w := res.Header.Get("X-Header"), "X-Value"; g != w { 1419 t.Errorf("X-Header mismatch\n\tgot: %q\n\twant: %q", g, w) 1420 } 1421 1422 if g, w := upgradeType(res.Header), "websocket"; !ascii.EqualFold(g, w) { 1423 t.Fatalf("Upgrade header mismatch\n\tgot: %q\n\twant: %q", g, w) 1424 } 1425 1426 rwc, ok := res.Body.(io.ReadWriteCloser) 1427 if !ok { 1428 t.Fatalf("Response body type mismatch, got %T, want io.ReadWriteCloser", res.Body) 1429 } 1430 1431 if got, want := res.Header.Get("X-Modified"), "true"; got != want { 1432 t.Errorf("response X-Modified header = %q; want %q", got, want) 1433 } 1434 1435 if _, err := io.WriteString(rwc, "Hello\n"); err != nil { 1436 t.Fatalf("Failed to write first message: %v", err) 1437 } 1438 1439 // Read loop. 1440 1441 br := bufio.NewReader(rwc) 1442 for { 1443 line, err := br.ReadString('\n') 1444 switch { 1445 case line == terminalMsg: // this case before "err == io.EOF" 1446 t.Fatalf("The websocket request was not canceled, unfortunately!") 1447 1448 case err == io.EOF: 1449 return 1450 1451 case err != nil: 1452 t.Fatalf("Unexpected error: %v", err) 1453 1454 case line == nthResponse(0): // We've gotten the first response back 1455 // Let's trigger a cancel. 1456 close(triggerCancelCh) 1457 } 1458 } 1459 } 1460 1461 func TestUnannouncedTrailer(t *testing.T) { 1462 backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 1463 w.WriteHeader(http.StatusOK) 1464 w.(http.Flusher).Flush() 1465 w.Header().Set(http.TrailerPrefix+"X-Unannounced-Trailer", "unannounced_trailer_value") 1466 })) 1467 defer backend.Close() 1468 backendURL, err := url.Parse(backend.URL) 1469 if err != nil { 1470 t.Fatal(err) 1471 } 1472 proxyHandler := NewSingleHostReverseProxy(backendURL) 1473 proxyHandler.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests 1474 frontend := httptest.NewServer(proxyHandler) 1475 defer frontend.Close() 1476 frontendClient := frontend.Client() 1477 1478 res, err := frontendClient.Get(frontend.URL) 1479 if err != nil { 1480 t.Fatalf("Get: %v", err) 1481 } 1482 1483 io.ReadAll(res.Body) 1484 1485 if g, w := res.Trailer.Get("X-Unannounced-Trailer"), "unannounced_trailer_value"; g != w { 1486 t.Errorf("Trailer(X-Unannounced-Trailer) = %q; want %q", g, w) 1487 } 1488 1489 } 1490 1491 func TestSingleJoinSlash(t *testing.T) { 1492 tests := []struct { 1493 slasha string 1494 slashb string 1495 expected string 1496 }{ 1497 {"https://www.google.com/", "/favicon.ico", "https://www.google.com/favicon.ico"}, 1498 {"https://www.google.com", "/favicon.ico", "https://www.google.com/favicon.ico"}, 1499 {"https://www.google.com", "favicon.ico", "https://www.google.com/favicon.ico"}, 1500 {"https://www.google.com", "", "https://www.google.com/"}, 1501 {"", "favicon.ico", "/favicon.ico"}, 1502 } 1503 for _, tt := range tests { 1504 if got := singleJoiningSlash(tt.slasha, tt.slashb); got != tt.expected { 1505 t.Errorf("singleJoiningSlash(%q,%q) want %q got %q", 1506 tt.slasha, 1507 tt.slashb, 1508 tt.expected, 1509 got) 1510 } 1511 } 1512 } 1513 1514 func TestJoinURLPath(t *testing.T) { 1515 tests := []struct { 1516 a *url.URL 1517 b *url.URL 1518 wantPath string 1519 wantRaw string 1520 }{ 1521 {&url.URL{Path: "/a/b"}, &url.URL{Path: "/c"}, "/a/b/c", ""}, 1522 {&url.URL{Path: "/a/b", RawPath: "badpath"}, &url.URL{Path: "c"}, "/a/b/c", "/a/b/c"}, 1523 {&url.URL{Path: "/a/b", RawPath: "/a%2Fb"}, &url.URL{Path: "/c"}, "/a/b/c", "/a%2Fb/c"}, 1524 {&url.URL{Path: "/a/b", RawPath: "/a%2Fb"}, &url.URL{Path: "/c"}, "/a/b/c", "/a%2Fb/c"}, 1525 {&url.URL{Path: "/a/b/", RawPath: "/a%2Fb%2F"}, &url.URL{Path: "c"}, "/a/b//c", "/a%2Fb%2F/c"}, 1526 {&url.URL{Path: "/a/b/", RawPath: "/a%2Fb/"}, &url.URL{Path: "/c/d", RawPath: "/c%2Fd"}, "/a/b/c/d", "/a%2Fb/c%2Fd"}, 1527 } 1528 1529 for _, tt := range tests { 1530 p, rp := joinURLPath(tt.a, tt.b) 1531 if p != tt.wantPath || rp != tt.wantRaw { 1532 t.Errorf("joinURLPath(URL(%q,%q),URL(%q,%q)) want (%q,%q) got (%q,%q)", 1533 tt.a.Path, tt.a.RawPath, 1534 tt.b.Path, tt.b.RawPath, 1535 tt.wantPath, tt.wantRaw, 1536 p, rp) 1537 } 1538 } 1539 }