github.com/mtsmfm/go/src@v0.0.0-20221020090648-44bdcb9f8fde/net/http/httputil/reverseproxy_test.go (about) 1 // Copyright 2011 The Go Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 // Reverse proxy tests. 6 7 package httputil 8 9 import ( 10 "bufio" 11 "bytes" 12 "context" 13 "errors" 14 "fmt" 15 "io" 16 "log" 17 "net/http" 18 "net/http/httptest" 19 "net/http/httptrace" 20 "github.com/mtsmfm/go/src/net/http/internal/ascii" 21 "net/textproto" 22 "net/url" 23 "os" 24 "reflect" 25 "sort" 26 "strconv" 27 "strings" 28 "sync" 29 "testing" 30 "time" 31 ) 32 33 const fakeHopHeader = "X-Fake-Hop-Header-For-Test" 34 35 func init() { 36 inOurTests = true 37 hopHeaders = append(hopHeaders, fakeHopHeader) 38 } 39 40 func TestReverseProxy(t *testing.T) { 41 const backendResponse = "I am the backend" 42 const backendStatus = 404 43 backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 44 if r.Method == "GET" && r.FormValue("mode") == "hangup" { 45 c, _, _ := w.(http.Hijacker).Hijack() 46 c.Close() 47 return 48 } 49 if len(r.TransferEncoding) > 0 { 50 t.Errorf("backend got unexpected TransferEncoding: %v", r.TransferEncoding) 51 } 52 if r.Header.Get("X-Forwarded-For") == "" { 53 t.Errorf("didn't get X-Forwarded-For header") 54 } 55 if r.Header.Get("X-Forwarded-Host") == "" { 56 t.Errorf("didn't get X-Forwarded-Host header") 57 } 58 if r.Header.Get("X-Forwarded-Proto") == "" { 59 t.Errorf("didn't get X-Forwarded-Proto header") 60 } 61 if c := r.Header.Get("Connection"); c != "" { 62 t.Errorf("handler got Connection header value %q", c) 63 } 64 if c := r.Header.Get("Te"); c != "trailers" { 65 t.Errorf("handler got Te header value %q; want 'trailers'", c) 66 } 67 if c := r.Header.Get("Upgrade"); c != "" { 68 t.Errorf("handler got Upgrade header value %q", c) 69 } 70 if c := r.Header.Get("Proxy-Connection"); c != "" { 71 t.Errorf("handler got Proxy-Connection header value %q", c) 72 } 73 if g, e := r.Host, "some-name"; g != e { 74 t.Errorf("backend got Host header %q, want %q", g, e) 75 } 76 w.Header().Set("Trailers", "not a special header field name") 77 w.Header().Set("Trailer", "X-Trailer") 78 w.Header().Set("X-Foo", "bar") 79 w.Header().Set("Upgrade", "foo") 80 w.Header().Set(fakeHopHeader, "foo") 81 w.Header().Add("X-Multi-Value", "foo") 82 w.Header().Add("X-Multi-Value", "bar") 83 http.SetCookie(w, &http.Cookie{Name: "flavor", Value: "chocolateChip"}) 84 w.WriteHeader(backendStatus) 85 w.Write([]byte(backendResponse)) 86 w.Header().Set("X-Trailer", "trailer_value") 87 w.Header().Set(http.TrailerPrefix+"X-Unannounced-Trailer", "unannounced_trailer_value") 88 })) 89 defer backend.Close() 90 backendURL, err := url.Parse(backend.URL) 91 if err != nil { 92 t.Fatal(err) 93 } 94 proxyHandler := NewSingleHostReverseProxy(backendURL) 95 proxyHandler.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests 96 frontend := httptest.NewServer(proxyHandler) 97 defer frontend.Close() 98 frontendClient := frontend.Client() 99 100 getReq, _ := http.NewRequest("GET", frontend.URL, nil) 101 getReq.Host = "some-name" 102 getReq.Header.Set("Connection", "close, TE") 103 getReq.Header.Add("Te", "foo") 104 getReq.Header.Add("Te", "bar, trailers") 105 getReq.Header.Set("Proxy-Connection", "should be deleted") 106 getReq.Header.Set("Upgrade", "foo") 107 getReq.Close = true 108 res, err := frontendClient.Do(getReq) 109 if err != nil { 110 t.Fatalf("Get: %v", err) 111 } 112 if g, e := res.StatusCode, backendStatus; g != e { 113 t.Errorf("got res.StatusCode %d; expected %d", g, e) 114 } 115 if g, e := res.Header.Get("X-Foo"), "bar"; g != e { 116 t.Errorf("got X-Foo %q; expected %q", g, e) 117 } 118 if c := res.Header.Get(fakeHopHeader); c != "" { 119 t.Errorf("got %s header value %q", fakeHopHeader, c) 120 } 121 if g, e := res.Header.Get("Trailers"), "not a special header field name"; g != e { 122 t.Errorf("header Trailers = %q; want %q", g, e) 123 } 124 if g, e := len(res.Header["X-Multi-Value"]), 2; g != e { 125 t.Errorf("got %d X-Multi-Value header values; expected %d", g, e) 126 } 127 if g, e := len(res.Header["Set-Cookie"]), 1; g != e { 128 t.Fatalf("got %d SetCookies, want %d", g, e) 129 } 130 if g, e := res.Trailer, (http.Header{"X-Trailer": nil}); !reflect.DeepEqual(g, e) { 131 t.Errorf("before reading body, Trailer = %#v; want %#v", g, e) 132 } 133 if cookie := res.Cookies()[0]; cookie.Name != "flavor" { 134 t.Errorf("unexpected cookie %q", cookie.Name) 135 } 136 bodyBytes, _ := io.ReadAll(res.Body) 137 if g, e := string(bodyBytes), backendResponse; g != e { 138 t.Errorf("got body %q; expected %q", g, e) 139 } 140 if g, e := res.Trailer.Get("X-Trailer"), "trailer_value"; g != e { 141 t.Errorf("Trailer(X-Trailer) = %q ; want %q", g, e) 142 } 143 if g, e := res.Trailer.Get("X-Unannounced-Trailer"), "unannounced_trailer_value"; g != e { 144 t.Errorf("Trailer(X-Unannounced-Trailer) = %q ; want %q", g, e) 145 } 146 147 // Test that a backend failing to be reached or one which doesn't return 148 // a response results in a StatusBadGateway. 149 getReq, _ = http.NewRequest("GET", frontend.URL+"/?mode=hangup", nil) 150 getReq.Close = true 151 res, err = frontendClient.Do(getReq) 152 if err != nil { 153 t.Fatal(err) 154 } 155 res.Body.Close() 156 if res.StatusCode != http.StatusBadGateway { 157 t.Errorf("request to bad proxy = %v; want 502 StatusBadGateway", res.Status) 158 } 159 160 } 161 162 // Issue 16875: remove any proxied headers mentioned in the "Connection" 163 // header value. 164 func TestReverseProxyStripHeadersPresentInConnection(t *testing.T) { 165 const fakeConnectionToken = "X-Fake-Connection-Token" 166 const backendResponse = "I am the backend" 167 168 // someConnHeader is some arbitrary header to be declared as a hop-by-hop header 169 // in the Request's Connection header. 170 const someConnHeader = "X-Some-Conn-Header" 171 172 backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 173 if c := r.Header.Get("Connection"); c != "" { 174 t.Errorf("handler got header %q = %q; want empty", "Connection", c) 175 } 176 if c := r.Header.Get(fakeConnectionToken); c != "" { 177 t.Errorf("handler got header %q = %q; want empty", fakeConnectionToken, c) 178 } 179 if c := r.Header.Get(someConnHeader); c != "" { 180 t.Errorf("handler got header %q = %q; want empty", someConnHeader, c) 181 } 182 w.Header().Add("Connection", "Upgrade, "+fakeConnectionToken) 183 w.Header().Add("Connection", someConnHeader) 184 w.Header().Set(someConnHeader, "should be deleted") 185 w.Header().Set(fakeConnectionToken, "should be deleted") 186 io.WriteString(w, backendResponse) 187 })) 188 defer backend.Close() 189 backendURL, err := url.Parse(backend.URL) 190 if err != nil { 191 t.Fatal(err) 192 } 193 proxyHandler := NewSingleHostReverseProxy(backendURL) 194 frontend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 195 proxyHandler.ServeHTTP(w, r) 196 if c := r.Header.Get(someConnHeader); c != "should be deleted" { 197 t.Errorf("handler modified header %q = %q; want %q", someConnHeader, c, "should be deleted") 198 } 199 if c := r.Header.Get(fakeConnectionToken); c != "should be deleted" { 200 t.Errorf("handler modified header %q = %q; want %q", fakeConnectionToken, c, "should be deleted") 201 } 202 c := r.Header["Connection"] 203 var cf []string 204 for _, f := range c { 205 for _, sf := range strings.Split(f, ",") { 206 if sf = strings.TrimSpace(sf); sf != "" { 207 cf = append(cf, sf) 208 } 209 } 210 } 211 sort.Strings(cf) 212 expectedValues := []string{"Upgrade", someConnHeader, fakeConnectionToken} 213 sort.Strings(expectedValues) 214 if !reflect.DeepEqual(cf, expectedValues) { 215 t.Errorf("handler modified header %q = %q; want %q", "Connection", cf, expectedValues) 216 } 217 })) 218 defer frontend.Close() 219 220 getReq, _ := http.NewRequest("GET", frontend.URL, nil) 221 getReq.Header.Add("Connection", "Upgrade, "+fakeConnectionToken) 222 getReq.Header.Add("Connection", someConnHeader) 223 getReq.Header.Set(someConnHeader, "should be deleted") 224 getReq.Header.Set(fakeConnectionToken, "should be deleted") 225 res, err := frontend.Client().Do(getReq) 226 if err != nil { 227 t.Fatalf("Get: %v", err) 228 } 229 defer res.Body.Close() 230 bodyBytes, err := io.ReadAll(res.Body) 231 if err != nil { 232 t.Fatalf("reading body: %v", err) 233 } 234 if got, want := string(bodyBytes), backendResponse; got != want { 235 t.Errorf("got body %q; want %q", got, want) 236 } 237 if c := res.Header.Get("Connection"); c != "" { 238 t.Errorf("handler got header %q = %q; want empty", "Connection", c) 239 } 240 if c := res.Header.Get(someConnHeader); c != "" { 241 t.Errorf("handler got header %q = %q; want empty", someConnHeader, c) 242 } 243 if c := res.Header.Get(fakeConnectionToken); c != "" { 244 t.Errorf("handler got header %q = %q; want empty", fakeConnectionToken, c) 245 } 246 } 247 248 func TestReverseProxyStripEmptyConnection(t *testing.T) { 249 // See Issue 46313. 250 const backendResponse = "I am the backend" 251 252 // someConnHeader is some arbitrary header to be declared as a hop-by-hop header 253 // in the Request's Connection header. 254 const someConnHeader = "X-Some-Conn-Header" 255 256 backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 257 if c := r.Header.Values("Connection"); len(c) != 0 { 258 t.Errorf("handler got header %q = %v; want empty", "Connection", c) 259 } 260 if c := r.Header.Get(someConnHeader); c != "" { 261 t.Errorf("handler got header %q = %q; want empty", someConnHeader, c) 262 } 263 w.Header().Add("Connection", "") 264 w.Header().Add("Connection", someConnHeader) 265 w.Header().Set(someConnHeader, "should be deleted") 266 io.WriteString(w, backendResponse) 267 })) 268 defer backend.Close() 269 backendURL, err := url.Parse(backend.URL) 270 if err != nil { 271 t.Fatal(err) 272 } 273 proxyHandler := NewSingleHostReverseProxy(backendURL) 274 frontend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 275 proxyHandler.ServeHTTP(w, r) 276 if c := r.Header.Get(someConnHeader); c != "should be deleted" { 277 t.Errorf("handler modified header %q = %q; want %q", someConnHeader, c, "should be deleted") 278 } 279 })) 280 defer frontend.Close() 281 282 getReq, _ := http.NewRequest("GET", frontend.URL, nil) 283 getReq.Header.Add("Connection", "") 284 getReq.Header.Add("Connection", someConnHeader) 285 getReq.Header.Set(someConnHeader, "should be deleted") 286 res, err := frontend.Client().Do(getReq) 287 if err != nil { 288 t.Fatalf("Get: %v", err) 289 } 290 defer res.Body.Close() 291 bodyBytes, err := io.ReadAll(res.Body) 292 if err != nil { 293 t.Fatalf("reading body: %v", err) 294 } 295 if got, want := string(bodyBytes), backendResponse; got != want { 296 t.Errorf("got body %q; want %q", got, want) 297 } 298 if c := res.Header.Get("Connection"); c != "" { 299 t.Errorf("handler got header %q = %q; want empty", "Connection", c) 300 } 301 if c := res.Header.Get(someConnHeader); c != "" { 302 t.Errorf("handler got header %q = %q; want empty", someConnHeader, c) 303 } 304 } 305 306 func TestXForwardedFor(t *testing.T) { 307 const prevForwardedFor = "client ip" 308 const backendResponse = "I am the backend" 309 const backendStatus = 404 310 const host = "some-name" 311 backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 312 if r.Header.Get("X-Forwarded-For") == "" { 313 t.Errorf("didn't get X-Forwarded-For header") 314 } 315 if !strings.Contains(r.Header.Get("X-Forwarded-For"), prevForwardedFor) { 316 t.Errorf("X-Forwarded-For didn't contain prior data") 317 } 318 if got, want := r.Header.Get("X-Forwarded-Host"), host; got != want { 319 t.Errorf("X-Forwarded-Host = %q, want %q", got, want) 320 } 321 if got, want := r.Header.Get("X-Forwarded-Proto"), "http"; got != want { 322 t.Errorf("X-Forwarded-Proto = %q, want %q", got, want) 323 } 324 w.WriteHeader(backendStatus) 325 w.Write([]byte(backendResponse)) 326 })) 327 defer backend.Close() 328 backendURL, err := url.Parse(backend.URL) 329 if err != nil { 330 t.Fatal(err) 331 } 332 proxyHandler := NewSingleHostReverseProxy(backendURL) 333 frontend := httptest.NewServer(proxyHandler) 334 defer frontend.Close() 335 336 getReq, _ := http.NewRequest("GET", frontend.URL, nil) 337 getReq.Host = host 338 getReq.Header.Set("Connection", "close") 339 getReq.Header.Set("X-Forwarded-For", prevForwardedFor) 340 getReq.Close = true 341 res, err := frontend.Client().Do(getReq) 342 if err != nil { 343 t.Fatalf("Get: %v", err) 344 } 345 if g, e := res.StatusCode, backendStatus; g != e { 346 t.Errorf("got res.StatusCode %d; expected %d", g, e) 347 } 348 bodyBytes, _ := io.ReadAll(res.Body) 349 if g, e := string(bodyBytes), backendResponse; g != e { 350 t.Errorf("got body %q; expected %q", g, e) 351 } 352 } 353 354 func TestXForwardedProtoTLS(t *testing.T) { 355 backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 356 if got, want := r.Header.Get("X-Forwarded-Proto"), "https"; got != want { 357 t.Errorf("X-Forwarded-Proto = %q, want %q", got, want) 358 } 359 })) 360 defer backend.Close() 361 backendURL, err := url.Parse(backend.URL) 362 if err != nil { 363 t.Fatal(err) 364 } 365 proxyHandler := NewSingleHostReverseProxy(backendURL) 366 frontend := httptest.NewTLSServer(proxyHandler) 367 defer frontend.Close() 368 369 getReq, _ := http.NewRequest("GET", frontend.URL, nil) 370 getReq.Host = "some-host" 371 _, err = frontend.Client().Do(getReq) 372 if err != nil { 373 t.Fatalf("Get: %v", err) 374 } 375 } 376 377 // Issue 38079: don't append to X-Forwarded-For if it's present but nil 378 func TestXForwardedFor_Omit(t *testing.T) { 379 backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 380 for _, h := range []string{"X-Forwarded-For", "X-Forwarded-Host", "X-Forwarded-Proto"} { 381 if v := r.Header.Get(h); v != "" { 382 t.Errorf("got %v header: %q", h, v) 383 } 384 } 385 w.Write([]byte("hi")) 386 })) 387 defer backend.Close() 388 backendURL, err := url.Parse(backend.URL) 389 if err != nil { 390 t.Fatal(err) 391 } 392 proxyHandler := NewSingleHostReverseProxy(backendURL) 393 frontend := httptest.NewServer(proxyHandler) 394 defer frontend.Close() 395 396 oldDirector := proxyHandler.Director 397 proxyHandler.Director = func(r *http.Request) { 398 r.Header["X-Forwarded-For"] = nil 399 r.Header["X-Forwarded-Host"] = nil 400 r.Header["X-Forwarded-Proto"] = nil 401 oldDirector(r) 402 } 403 404 getReq, _ := http.NewRequest("GET", frontend.URL, nil) 405 getReq.Host = "some-name" 406 getReq.Close = true 407 res, err := frontend.Client().Do(getReq) 408 if err != nil { 409 t.Fatalf("Get: %v", err) 410 } 411 res.Body.Close() 412 } 413 414 func TestReverseProxyRewriteStripsForwarded(t *testing.T) { 415 headers := []string{ 416 "Forwarded", 417 "X-Forwarded-For", 418 "X-Forwarded-Host", 419 "X-Forwarded-Proto", 420 } 421 backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 422 for _, h := range headers { 423 if v := r.Header.Get(h); v != "" { 424 t.Errorf("got %v header: %q", h, v) 425 } 426 } 427 })) 428 defer backend.Close() 429 backendURL, err := url.Parse(backend.URL) 430 if err != nil { 431 t.Fatal(err) 432 } 433 proxyHandler := &ReverseProxy{ 434 Rewrite: func(r *ProxyRequest) { 435 r.SetURL(backendURL) 436 }, 437 } 438 frontend := httptest.NewServer(proxyHandler) 439 defer frontend.Close() 440 441 getReq, _ := http.NewRequest("GET", frontend.URL, nil) 442 getReq.Host = "some-name" 443 getReq.Close = true 444 for _, h := range headers { 445 getReq.Header.Set(h, "x") 446 } 447 res, err := frontend.Client().Do(getReq) 448 if err != nil { 449 t.Fatalf("Get: %v", err) 450 } 451 res.Body.Close() 452 } 453 454 var proxyQueryTests = []struct { 455 baseSuffix string // suffix to add to backend URL 456 reqSuffix string // suffix to add to frontend's request URL 457 want string // what backend should see for final request URL (without ?) 458 }{ 459 {"", "", ""}, 460 {"?sta=tic", "?us=er", "sta=tic&us=er"}, 461 {"", "?us=er", "us=er"}, 462 {"?sta=tic", "", "sta=tic"}, 463 } 464 465 func TestReverseProxyQuery(t *testing.T) { 466 backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 467 w.Header().Set("X-Got-Query", r.URL.RawQuery) 468 w.Write([]byte("hi")) 469 })) 470 defer backend.Close() 471 472 for i, tt := range proxyQueryTests { 473 backendURL, err := url.Parse(backend.URL + tt.baseSuffix) 474 if err != nil { 475 t.Fatal(err) 476 } 477 frontend := httptest.NewServer(NewSingleHostReverseProxy(backendURL)) 478 req, _ := http.NewRequest("GET", frontend.URL+tt.reqSuffix, nil) 479 req.Close = true 480 res, err := frontend.Client().Do(req) 481 if err != nil { 482 t.Fatalf("%d. Get: %v", i, err) 483 } 484 if g, e := res.Header.Get("X-Got-Query"), tt.want; g != e { 485 t.Errorf("%d. got query %q; expected %q", i, g, e) 486 } 487 res.Body.Close() 488 frontend.Close() 489 } 490 } 491 492 func TestReverseProxyFlushInterval(t *testing.T) { 493 const expected = "hi" 494 backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 495 w.Write([]byte(expected)) 496 })) 497 defer backend.Close() 498 499 backendURL, err := url.Parse(backend.URL) 500 if err != nil { 501 t.Fatal(err) 502 } 503 504 proxyHandler := NewSingleHostReverseProxy(backendURL) 505 proxyHandler.FlushInterval = time.Microsecond 506 507 frontend := httptest.NewServer(proxyHandler) 508 defer frontend.Close() 509 510 req, _ := http.NewRequest("GET", frontend.URL, nil) 511 req.Close = true 512 res, err := frontend.Client().Do(req) 513 if err != nil { 514 t.Fatalf("Get: %v", err) 515 } 516 defer res.Body.Close() 517 if bodyBytes, _ := io.ReadAll(res.Body); string(bodyBytes) != expected { 518 t.Errorf("got body %q; expected %q", bodyBytes, expected) 519 } 520 } 521 522 func TestReverseProxyFlushIntervalHeaders(t *testing.T) { 523 const expected = "hi" 524 stopCh := make(chan struct{}) 525 backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 526 w.Header().Add("MyHeader", expected) 527 w.WriteHeader(200) 528 w.(http.Flusher).Flush() 529 <-stopCh 530 })) 531 defer backend.Close() 532 defer close(stopCh) 533 534 backendURL, err := url.Parse(backend.URL) 535 if err != nil { 536 t.Fatal(err) 537 } 538 539 proxyHandler := NewSingleHostReverseProxy(backendURL) 540 proxyHandler.FlushInterval = time.Microsecond 541 542 frontend := httptest.NewServer(proxyHandler) 543 defer frontend.Close() 544 545 req, _ := http.NewRequest("GET", frontend.URL, nil) 546 req.Close = true 547 548 ctx, cancel := context.WithTimeout(req.Context(), 10*time.Second) 549 defer cancel() 550 req = req.WithContext(ctx) 551 552 res, err := frontend.Client().Do(req) 553 if err != nil { 554 t.Fatalf("Get: %v", err) 555 } 556 defer res.Body.Close() 557 558 if res.Header.Get("MyHeader") != expected { 559 t.Errorf("got header %q; expected %q", res.Header.Get("MyHeader"), expected) 560 } 561 } 562 563 func TestReverseProxyCancellation(t *testing.T) { 564 const backendResponse = "I am the backend" 565 566 reqInFlight := make(chan struct{}) 567 backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 568 close(reqInFlight) // cause the client to cancel its request 569 570 select { 571 case <-time.After(10 * time.Second): 572 // Note: this should only happen in broken implementations, and the 573 // closenotify case should be instantaneous. 574 t.Error("Handler never saw CloseNotify") 575 return 576 case <-w.(http.CloseNotifier).CloseNotify(): 577 } 578 579 w.WriteHeader(http.StatusOK) 580 w.Write([]byte(backendResponse)) 581 })) 582 583 defer backend.Close() 584 585 backend.Config.ErrorLog = log.New(io.Discard, "", 0) 586 587 backendURL, err := url.Parse(backend.URL) 588 if err != nil { 589 t.Fatal(err) 590 } 591 592 proxyHandler := NewSingleHostReverseProxy(backendURL) 593 594 // Discards errors of the form: 595 // http: proxy error: read tcp 127.0.0.1:44643: use of closed network connection 596 proxyHandler.ErrorLog = log.New(io.Discard, "", 0) 597 598 frontend := httptest.NewServer(proxyHandler) 599 defer frontend.Close() 600 frontendClient := frontend.Client() 601 602 getReq, _ := http.NewRequest("GET", frontend.URL, nil) 603 go func() { 604 <-reqInFlight 605 frontendClient.Transport.(*http.Transport).CancelRequest(getReq) 606 }() 607 res, err := frontendClient.Do(getReq) 608 if res != nil { 609 t.Errorf("got response %v; want nil", res.Status) 610 } 611 if err == nil { 612 // This should be an error like: 613 // Get "http://127.0.0.1:58079": read tcp 127.0.0.1:58079: 614 // use of closed network connection 615 t.Error("Server.Client().Do() returned nil error; want non-nil error") 616 } 617 } 618 619 func req(t *testing.T, v string) *http.Request { 620 req, err := http.ReadRequest(bufio.NewReader(strings.NewReader(v))) 621 if err != nil { 622 t.Fatal(err) 623 } 624 return req 625 } 626 627 // Issue 12344 628 func TestNilBody(t *testing.T) { 629 backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 630 w.Write([]byte("hi")) 631 })) 632 defer backend.Close() 633 634 frontend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { 635 backURL, _ := url.Parse(backend.URL) 636 rp := NewSingleHostReverseProxy(backURL) 637 r := req(t, "GET / HTTP/1.0\r\n\r\n") 638 r.Body = nil // this accidentally worked in Go 1.4 and below, so keep it working 639 rp.ServeHTTP(w, r) 640 })) 641 defer frontend.Close() 642 643 res, err := http.Get(frontend.URL) 644 if err != nil { 645 t.Fatal(err) 646 } 647 defer res.Body.Close() 648 slurp, err := io.ReadAll(res.Body) 649 if err != nil { 650 t.Fatal(err) 651 } 652 if string(slurp) != "hi" { 653 t.Errorf("Got %q; want %q", slurp, "hi") 654 } 655 } 656 657 // Issue 15524 658 func TestUserAgentHeader(t *testing.T) { 659 var gotUA string 660 backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 661 gotUA = r.Header.Get("User-Agent") 662 })) 663 defer backend.Close() 664 backendURL, err := url.Parse(backend.URL) 665 if err != nil { 666 t.Fatal(err) 667 } 668 669 proxyHandler := new(ReverseProxy) 670 proxyHandler.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests 671 proxyHandler.Director = func(req *http.Request) { 672 req.URL = backendURL 673 } 674 frontend := httptest.NewServer(proxyHandler) 675 defer frontend.Close() 676 frontendClient := frontend.Client() 677 678 for _, sentUA := range []string{"explicit UA", ""} { 679 getReq, _ := http.NewRequest("GET", frontend.URL, nil) 680 getReq.Header.Set("User-Agent", sentUA) 681 getReq.Close = true 682 res, err := frontendClient.Do(getReq) 683 if err != nil { 684 t.Fatalf("Get: %v", err) 685 } 686 res.Body.Close() 687 if got, want := gotUA, sentUA; got != want { 688 t.Errorf("got forwarded User-Agent %q, want %q", got, want) 689 } 690 } 691 } 692 693 type bufferPool struct { 694 get func() []byte 695 put func([]byte) 696 } 697 698 func (bp bufferPool) Get() []byte { return bp.get() } 699 func (bp bufferPool) Put(v []byte) { bp.put(v) } 700 701 func TestReverseProxyGetPutBuffer(t *testing.T) { 702 const msg = "hi" 703 backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 704 io.WriteString(w, msg) 705 })) 706 defer backend.Close() 707 708 backendURL, err := url.Parse(backend.URL) 709 if err != nil { 710 t.Fatal(err) 711 } 712 713 var ( 714 mu sync.Mutex 715 log []string 716 ) 717 addLog := func(event string) { 718 mu.Lock() 719 defer mu.Unlock() 720 log = append(log, event) 721 } 722 rp := NewSingleHostReverseProxy(backendURL) 723 const size = 1234 724 rp.BufferPool = bufferPool{ 725 get: func() []byte { 726 addLog("getBuf") 727 return make([]byte, size) 728 }, 729 put: func(p []byte) { 730 addLog("putBuf-" + strconv.Itoa(len(p))) 731 }, 732 } 733 frontend := httptest.NewServer(rp) 734 defer frontend.Close() 735 736 req, _ := http.NewRequest("GET", frontend.URL, nil) 737 req.Close = true 738 res, err := frontend.Client().Do(req) 739 if err != nil { 740 t.Fatalf("Get: %v", err) 741 } 742 slurp, err := io.ReadAll(res.Body) 743 res.Body.Close() 744 if err != nil { 745 t.Fatalf("reading body: %v", err) 746 } 747 if string(slurp) != msg { 748 t.Errorf("msg = %q; want %q", slurp, msg) 749 } 750 wantLog := []string{"getBuf", "putBuf-" + strconv.Itoa(size)} 751 mu.Lock() 752 defer mu.Unlock() 753 if !reflect.DeepEqual(log, wantLog) { 754 t.Errorf("Log events = %q; want %q", log, wantLog) 755 } 756 } 757 758 func TestReverseProxy_Post(t *testing.T) { 759 const backendResponse = "I am the backend" 760 const backendStatus = 200 761 var requestBody = bytes.Repeat([]byte("a"), 1<<20) 762 backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 763 slurp, err := io.ReadAll(r.Body) 764 if err != nil { 765 t.Errorf("Backend body read = %v", err) 766 } 767 if len(slurp) != len(requestBody) { 768 t.Errorf("Backend read %d request body bytes; want %d", len(slurp), len(requestBody)) 769 } 770 if !bytes.Equal(slurp, requestBody) { 771 t.Error("Backend read wrong request body.") // 1MB; omitting details 772 } 773 w.Write([]byte(backendResponse)) 774 })) 775 defer backend.Close() 776 backendURL, err := url.Parse(backend.URL) 777 if err != nil { 778 t.Fatal(err) 779 } 780 proxyHandler := NewSingleHostReverseProxy(backendURL) 781 frontend := httptest.NewServer(proxyHandler) 782 defer frontend.Close() 783 784 postReq, _ := http.NewRequest("POST", frontend.URL, bytes.NewReader(requestBody)) 785 res, err := frontend.Client().Do(postReq) 786 if err != nil { 787 t.Fatalf("Do: %v", err) 788 } 789 if g, e := res.StatusCode, backendStatus; g != e { 790 t.Errorf("got res.StatusCode %d; expected %d", g, e) 791 } 792 bodyBytes, _ := io.ReadAll(res.Body) 793 if g, e := string(bodyBytes), backendResponse; g != e { 794 t.Errorf("got body %q; expected %q", g, e) 795 } 796 } 797 798 type RoundTripperFunc func(*http.Request) (*http.Response, error) 799 800 func (fn RoundTripperFunc) RoundTrip(req *http.Request) (*http.Response, error) { 801 return fn(req) 802 } 803 804 // Issue 16036: send a Request with a nil Body when possible 805 func TestReverseProxy_NilBody(t *testing.T) { 806 backendURL, _ := url.Parse("http://fake.tld/") 807 proxyHandler := NewSingleHostReverseProxy(backendURL) 808 proxyHandler.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests 809 proxyHandler.Transport = RoundTripperFunc(func(req *http.Request) (*http.Response, error) { 810 if req.Body != nil { 811 t.Error("Body != nil; want a nil Body") 812 } 813 return nil, errors.New("done testing the interesting part; so force a 502 Gateway error") 814 }) 815 frontend := httptest.NewServer(proxyHandler) 816 defer frontend.Close() 817 818 res, err := frontend.Client().Get(frontend.URL) 819 if err != nil { 820 t.Fatal(err) 821 } 822 defer res.Body.Close() 823 if res.StatusCode != 502 { 824 t.Errorf("status code = %v; want 502 (Gateway Error)", res.Status) 825 } 826 } 827 828 // Issue 33142: always allocate the request headers 829 func TestReverseProxy_AllocatedHeader(t *testing.T) { 830 proxyHandler := new(ReverseProxy) 831 proxyHandler.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests 832 proxyHandler.Director = func(*http.Request) {} // noop 833 proxyHandler.Transport = RoundTripperFunc(func(req *http.Request) (*http.Response, error) { 834 if req.Header == nil { 835 t.Error("Header == nil; want a non-nil Header") 836 } 837 return nil, errors.New("done testing the interesting part; so force a 502 Gateway error") 838 }) 839 840 proxyHandler.ServeHTTP(httptest.NewRecorder(), &http.Request{ 841 Method: "GET", 842 URL: &url.URL{Scheme: "http", Host: "fake.tld", Path: "/"}, 843 Proto: "HTTP/1.0", 844 ProtoMajor: 1, 845 }) 846 } 847 848 // Issue 14237. Test ModifyResponse and that an error from it 849 // causes the proxy to return StatusBadGateway, or StatusOK otherwise. 850 func TestReverseProxyModifyResponse(t *testing.T) { 851 backendServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 852 w.Header().Add("X-Hit-Mod", fmt.Sprintf("%v", r.URL.Path == "/mod")) 853 })) 854 defer backendServer.Close() 855 856 rpURL, _ := url.Parse(backendServer.URL) 857 rproxy := NewSingleHostReverseProxy(rpURL) 858 rproxy.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests 859 rproxy.ModifyResponse = func(resp *http.Response) error { 860 if resp.Header.Get("X-Hit-Mod") != "true" { 861 return fmt.Errorf("tried to by-pass proxy") 862 } 863 return nil 864 } 865 866 frontendProxy := httptest.NewServer(rproxy) 867 defer frontendProxy.Close() 868 869 tests := []struct { 870 url string 871 wantCode int 872 }{ 873 {frontendProxy.URL + "/mod", http.StatusOK}, 874 {frontendProxy.URL + "/schedule", http.StatusBadGateway}, 875 } 876 877 for i, tt := range tests { 878 resp, err := http.Get(tt.url) 879 if err != nil { 880 t.Fatalf("failed to reach proxy: %v", err) 881 } 882 if g, e := resp.StatusCode, tt.wantCode; g != e { 883 t.Errorf("#%d: got res.StatusCode %d; expected %d", i, g, e) 884 } 885 resp.Body.Close() 886 } 887 } 888 889 type failingRoundTripper struct{} 890 891 func (failingRoundTripper) RoundTrip(*http.Request) (*http.Response, error) { 892 return nil, errors.New("some error") 893 } 894 895 type staticResponseRoundTripper struct{ res *http.Response } 896 897 func (rt staticResponseRoundTripper) RoundTrip(*http.Request) (*http.Response, error) { 898 return rt.res, nil 899 } 900 901 func TestReverseProxyErrorHandler(t *testing.T) { 902 tests := []struct { 903 name string 904 wantCode int 905 errorHandler func(http.ResponseWriter, *http.Request, error) 906 transport http.RoundTripper // defaults to failingRoundTripper 907 modifyResponse func(*http.Response) error 908 }{ 909 { 910 name: "default", 911 wantCode: http.StatusBadGateway, 912 }, 913 { 914 name: "errorhandler", 915 wantCode: http.StatusTeapot, 916 errorHandler: func(rw http.ResponseWriter, req *http.Request, err error) { rw.WriteHeader(http.StatusTeapot) }, 917 }, 918 { 919 name: "modifyresponse_noerr", 920 transport: staticResponseRoundTripper{ 921 &http.Response{StatusCode: 345, Body: http.NoBody}, 922 }, 923 modifyResponse: func(res *http.Response) error { 924 res.StatusCode++ 925 return nil 926 }, 927 errorHandler: func(rw http.ResponseWriter, req *http.Request, err error) { rw.WriteHeader(http.StatusTeapot) }, 928 wantCode: 346, 929 }, 930 { 931 name: "modifyresponse_err", 932 transport: staticResponseRoundTripper{ 933 &http.Response{StatusCode: 345, Body: http.NoBody}, 934 }, 935 modifyResponse: func(res *http.Response) error { 936 res.StatusCode++ 937 return errors.New("some error to trigger errorHandler") 938 }, 939 errorHandler: func(rw http.ResponseWriter, req *http.Request, err error) { rw.WriteHeader(http.StatusTeapot) }, 940 wantCode: http.StatusTeapot, 941 }, 942 } 943 944 for _, tt := range tests { 945 t.Run(tt.name, func(t *testing.T) { 946 target := &url.URL{ 947 Scheme: "http", 948 Host: "dummy.tld", 949 Path: "/", 950 } 951 rproxy := NewSingleHostReverseProxy(target) 952 rproxy.Transport = tt.transport 953 rproxy.ModifyResponse = tt.modifyResponse 954 if rproxy.Transport == nil { 955 rproxy.Transport = failingRoundTripper{} 956 } 957 rproxy.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests 958 if tt.errorHandler != nil { 959 rproxy.ErrorHandler = tt.errorHandler 960 } 961 frontendProxy := httptest.NewServer(rproxy) 962 defer frontendProxy.Close() 963 964 resp, err := http.Get(frontendProxy.URL + "/test") 965 if err != nil { 966 t.Fatalf("failed to reach proxy: %v", err) 967 } 968 if g, e := resp.StatusCode, tt.wantCode; g != e { 969 t.Errorf("got res.StatusCode %d; expected %d", g, e) 970 } 971 resp.Body.Close() 972 }) 973 } 974 } 975 976 // Issue 16659: log errors from short read 977 func TestReverseProxy_CopyBuffer(t *testing.T) { 978 backendServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 979 out := "this call was relayed by the reverse proxy" 980 // Coerce a wrong content length to induce io.UnexpectedEOF 981 w.Header().Set("Content-Length", fmt.Sprintf("%d", len(out)*2)) 982 fmt.Fprintln(w, out) 983 })) 984 defer backendServer.Close() 985 986 rpURL, err := url.Parse(backendServer.URL) 987 if err != nil { 988 t.Fatal(err) 989 } 990 991 var proxyLog bytes.Buffer 992 rproxy := NewSingleHostReverseProxy(rpURL) 993 rproxy.ErrorLog = log.New(&proxyLog, "", log.Lshortfile) 994 donec := make(chan bool, 1) 995 frontendProxy := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 996 defer func() { donec <- true }() 997 rproxy.ServeHTTP(w, r) 998 })) 999 defer frontendProxy.Close() 1000 1001 if _, err = frontendProxy.Client().Get(frontendProxy.URL); err == nil { 1002 t.Fatalf("want non-nil error") 1003 } 1004 // The race detector complains about the proxyLog usage in logf in copyBuffer 1005 // and our usage below with proxyLog.Bytes() so we're explicitly using a 1006 // channel to ensure that the ReverseProxy's ServeHTTP is done before we 1007 // continue after Get. 1008 <-donec 1009 1010 expected := []string{ 1011 "EOF", 1012 "read", 1013 } 1014 for _, phrase := range expected { 1015 if !bytes.Contains(proxyLog.Bytes(), []byte(phrase)) { 1016 t.Errorf("expected log to contain phrase %q", phrase) 1017 } 1018 } 1019 } 1020 1021 type staticTransport struct { 1022 res *http.Response 1023 } 1024 1025 func (t *staticTransport) RoundTrip(r *http.Request) (*http.Response, error) { 1026 return t.res, nil 1027 } 1028 1029 func BenchmarkServeHTTP(b *testing.B) { 1030 res := &http.Response{ 1031 StatusCode: 200, 1032 Body: io.NopCloser(strings.NewReader("")), 1033 } 1034 proxy := &ReverseProxy{ 1035 Director: func(*http.Request) {}, 1036 Transport: &staticTransport{res}, 1037 } 1038 1039 w := httptest.NewRecorder() 1040 r := httptest.NewRequest("GET", "/", nil) 1041 1042 b.ReportAllocs() 1043 for i := 0; i < b.N; i++ { 1044 proxy.ServeHTTP(w, r) 1045 } 1046 } 1047 1048 func TestServeHTTPDeepCopy(t *testing.T) { 1049 backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 1050 w.Write([]byte("Hello Gopher!")) 1051 })) 1052 defer backend.Close() 1053 backendURL, err := url.Parse(backend.URL) 1054 if err != nil { 1055 t.Fatal(err) 1056 } 1057 1058 type result struct { 1059 before, after string 1060 } 1061 1062 resultChan := make(chan result, 1) 1063 proxyHandler := NewSingleHostReverseProxy(backendURL) 1064 frontend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 1065 before := r.URL.String() 1066 proxyHandler.ServeHTTP(w, r) 1067 after := r.URL.String() 1068 resultChan <- result{before: before, after: after} 1069 })) 1070 defer frontend.Close() 1071 1072 want := result{before: "/", after: "/"} 1073 1074 res, err := frontend.Client().Get(frontend.URL) 1075 if err != nil { 1076 t.Fatalf("Do: %v", err) 1077 } 1078 res.Body.Close() 1079 1080 got := <-resultChan 1081 if got != want { 1082 t.Errorf("got = %+v; want = %+v", got, want) 1083 } 1084 } 1085 1086 // Issue 18327: verify we always do a deep copy of the Request.Header map 1087 // before any mutations. 1088 func TestClonesRequestHeaders(t *testing.T) { 1089 log.SetOutput(io.Discard) 1090 defer log.SetOutput(os.Stderr) 1091 req, _ := http.NewRequest("GET", "http://foo.tld/", nil) 1092 req.RemoteAddr = "1.2.3.4:56789" 1093 rp := &ReverseProxy{ 1094 Director: func(req *http.Request) { 1095 req.Header.Set("From-Director", "1") 1096 }, 1097 Transport: roundTripperFunc(func(req *http.Request) (*http.Response, error) { 1098 if v := req.Header.Get("From-Director"); v != "1" { 1099 t.Errorf("From-Directory value = %q; want 1", v) 1100 } 1101 return nil, io.EOF 1102 }), 1103 } 1104 rp.ServeHTTP(httptest.NewRecorder(), req) 1105 1106 for _, h := range []string{ 1107 "From-Director", 1108 "X-Forwarded-For", 1109 "X-Forwarded-Host", 1110 "X-Forwarded-Proto", 1111 } { 1112 if req.Header.Get(h) != "" { 1113 t.Errorf("%v header mutation modified caller's request", h) 1114 } 1115 } 1116 } 1117 1118 type roundTripperFunc func(req *http.Request) (*http.Response, error) 1119 1120 func (fn roundTripperFunc) RoundTrip(req *http.Request) (*http.Response, error) { 1121 return fn(req) 1122 } 1123 1124 func TestModifyResponseClosesBody(t *testing.T) { 1125 req, _ := http.NewRequest("GET", "http://foo.tld/", nil) 1126 req.RemoteAddr = "1.2.3.4:56789" 1127 closeCheck := new(checkCloser) 1128 logBuf := new(strings.Builder) 1129 outErr := errors.New("ModifyResponse error") 1130 rp := &ReverseProxy{ 1131 Director: func(req *http.Request) {}, 1132 Transport: &staticTransport{&http.Response{ 1133 StatusCode: 200, 1134 Body: closeCheck, 1135 }}, 1136 ErrorLog: log.New(logBuf, "", 0), 1137 ModifyResponse: func(*http.Response) error { 1138 return outErr 1139 }, 1140 } 1141 rec := httptest.NewRecorder() 1142 rp.ServeHTTP(rec, req) 1143 res := rec.Result() 1144 if g, e := res.StatusCode, http.StatusBadGateway; g != e { 1145 t.Errorf("got res.StatusCode %d; expected %d", g, e) 1146 } 1147 if !closeCheck.closed { 1148 t.Errorf("body should have been closed") 1149 } 1150 if g, e := logBuf.String(), outErr.Error(); !strings.Contains(g, e) { 1151 t.Errorf("ErrorLog %q does not contain %q", g, e) 1152 } 1153 } 1154 1155 type checkCloser struct { 1156 closed bool 1157 } 1158 1159 func (cc *checkCloser) Close() error { 1160 cc.closed = true 1161 return nil 1162 } 1163 1164 func (cc *checkCloser) Read(b []byte) (int, error) { 1165 return len(b), nil 1166 } 1167 1168 // Issue 23643: panic on body copy error 1169 func TestReverseProxy_PanicBodyError(t *testing.T) { 1170 log.SetOutput(io.Discard) 1171 defer log.SetOutput(os.Stderr) 1172 backendServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 1173 out := "this call was relayed by the reverse proxy" 1174 // Coerce a wrong content length to induce io.ErrUnexpectedEOF 1175 w.Header().Set("Content-Length", fmt.Sprintf("%d", len(out)*2)) 1176 fmt.Fprintln(w, out) 1177 })) 1178 defer backendServer.Close() 1179 1180 rpURL, err := url.Parse(backendServer.URL) 1181 if err != nil { 1182 t.Fatal(err) 1183 } 1184 1185 rproxy := NewSingleHostReverseProxy(rpURL) 1186 1187 // Ensure that the handler panics when the body read encounters an 1188 // io.ErrUnexpectedEOF 1189 defer func() { 1190 err := recover() 1191 if err == nil { 1192 t.Fatal("handler should have panicked") 1193 } 1194 if err != http.ErrAbortHandler { 1195 t.Fatal("expected ErrAbortHandler, got", err) 1196 } 1197 }() 1198 req, _ := http.NewRequest("GET", "http://foo.tld/", nil) 1199 rproxy.ServeHTTP(httptest.NewRecorder(), req) 1200 } 1201 1202 // Issue #46866: panic without closing incoming request body causes a panic 1203 func TestReverseProxy_PanicClosesIncomingBody(t *testing.T) { 1204 backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 1205 out := "this call was relayed by the reverse proxy" 1206 // Coerce a wrong content length to induce io.ErrUnexpectedEOF 1207 w.Header().Set("Content-Length", fmt.Sprintf("%d", len(out)*2)) 1208 fmt.Fprintln(w, out) 1209 })) 1210 defer backend.Close() 1211 backendURL, err := url.Parse(backend.URL) 1212 if err != nil { 1213 t.Fatal(err) 1214 } 1215 proxyHandler := NewSingleHostReverseProxy(backendURL) 1216 proxyHandler.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests 1217 frontend := httptest.NewServer(proxyHandler) 1218 defer frontend.Close() 1219 frontendClient := frontend.Client() 1220 1221 var wg sync.WaitGroup 1222 for i := 0; i < 2; i++ { 1223 wg.Add(1) 1224 go func() { 1225 defer wg.Done() 1226 for j := 0; j < 10; j++ { 1227 const reqLen = 6 * 1024 * 1024 1228 req, _ := http.NewRequest("POST", frontend.URL, &io.LimitedReader{R: neverEnding('x'), N: reqLen}) 1229 req.ContentLength = reqLen 1230 resp, _ := frontendClient.Transport.RoundTrip(req) 1231 if resp != nil { 1232 io.Copy(io.Discard, resp.Body) 1233 resp.Body.Close() 1234 } 1235 } 1236 }() 1237 } 1238 wg.Wait() 1239 } 1240 1241 func TestSelectFlushInterval(t *testing.T) { 1242 tests := []struct { 1243 name string 1244 p *ReverseProxy 1245 res *http.Response 1246 want time.Duration 1247 }{ 1248 { 1249 name: "default", 1250 res: &http.Response{}, 1251 p: &ReverseProxy{FlushInterval: 123}, 1252 want: 123, 1253 }, 1254 { 1255 name: "server-sent events overrides non-zero", 1256 res: &http.Response{ 1257 Header: http.Header{ 1258 "Content-Type": {"text/event-stream"}, 1259 }, 1260 }, 1261 p: &ReverseProxy{FlushInterval: 123}, 1262 want: -1, 1263 }, 1264 { 1265 name: "server-sent events overrides zero", 1266 res: &http.Response{ 1267 Header: http.Header{ 1268 "Content-Type": {"text/event-stream"}, 1269 }, 1270 }, 1271 p: &ReverseProxy{FlushInterval: 0}, 1272 want: -1, 1273 }, 1274 { 1275 name: "server-sent events with media-type parameters overrides non-zero", 1276 res: &http.Response{ 1277 Header: http.Header{ 1278 "Content-Type": {"text/event-stream;charset=utf-8"}, 1279 }, 1280 }, 1281 p: &ReverseProxy{FlushInterval: 123}, 1282 want: -1, 1283 }, 1284 { 1285 name: "server-sent events with media-type parameters overrides zero", 1286 res: &http.Response{ 1287 Header: http.Header{ 1288 "Content-Type": {"text/event-stream;charset=utf-8"}, 1289 }, 1290 }, 1291 p: &ReverseProxy{FlushInterval: 0}, 1292 want: -1, 1293 }, 1294 { 1295 name: "Content-Length: -1, overrides non-zero", 1296 res: &http.Response{ 1297 ContentLength: -1, 1298 }, 1299 p: &ReverseProxy{FlushInterval: 123}, 1300 want: -1, 1301 }, 1302 { 1303 name: "Content-Length: -1, overrides zero", 1304 res: &http.Response{ 1305 ContentLength: -1, 1306 }, 1307 p: &ReverseProxy{FlushInterval: 0}, 1308 want: -1, 1309 }, 1310 } 1311 for _, tt := range tests { 1312 t.Run(tt.name, func(t *testing.T) { 1313 got := tt.p.flushInterval(tt.res) 1314 if got != tt.want { 1315 t.Errorf("flushLatency = %v; want %v", got, tt.want) 1316 } 1317 }) 1318 } 1319 } 1320 1321 func TestReverseProxyWebSocket(t *testing.T) { 1322 backendServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 1323 if upgradeType(r.Header) != "websocket" { 1324 t.Error("unexpected backend request") 1325 http.Error(w, "unexpected request", 400) 1326 return 1327 } 1328 c, _, err := w.(http.Hijacker).Hijack() 1329 if err != nil { 1330 t.Error(err) 1331 return 1332 } 1333 defer c.Close() 1334 io.WriteString(c, "HTTP/1.1 101 Switching Protocols\r\nConnection: upgrade\r\nUpgrade: WebSocket\r\n\r\n") 1335 bs := bufio.NewScanner(c) 1336 if !bs.Scan() { 1337 t.Errorf("backend failed to read line from client: %v", bs.Err()) 1338 return 1339 } 1340 fmt.Fprintf(c, "backend got %q\n", bs.Text()) 1341 })) 1342 defer backendServer.Close() 1343 1344 backURL, _ := url.Parse(backendServer.URL) 1345 rproxy := NewSingleHostReverseProxy(backURL) 1346 rproxy.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests 1347 rproxy.ModifyResponse = func(res *http.Response) error { 1348 res.Header.Add("X-Modified", "true") 1349 return nil 1350 } 1351 1352 handler := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { 1353 rw.Header().Set("X-Header", "X-Value") 1354 rproxy.ServeHTTP(rw, req) 1355 if got, want := rw.Header().Get("X-Modified"), "true"; got != want { 1356 t.Errorf("response writer X-Modified header = %q; want %q", got, want) 1357 } 1358 }) 1359 1360 frontendProxy := httptest.NewServer(handler) 1361 defer frontendProxy.Close() 1362 1363 req, _ := http.NewRequest("GET", frontendProxy.URL, nil) 1364 req.Header.Set("Connection", "Upgrade") 1365 req.Header.Set("Upgrade", "websocket") 1366 1367 c := frontendProxy.Client() 1368 res, err := c.Do(req) 1369 if err != nil { 1370 t.Fatal(err) 1371 } 1372 if res.StatusCode != 101 { 1373 t.Fatalf("status = %v; want 101", res.Status) 1374 } 1375 1376 got := res.Header.Get("X-Header") 1377 want := "X-Value" 1378 if got != want { 1379 t.Errorf("Header(XHeader) = %q; want %q", got, want) 1380 } 1381 1382 if !ascii.EqualFold(upgradeType(res.Header), "websocket") { 1383 t.Fatalf("not websocket upgrade; got %#v", res.Header) 1384 } 1385 rwc, ok := res.Body.(io.ReadWriteCloser) 1386 if !ok { 1387 t.Fatalf("response body is of type %T; does not implement ReadWriteCloser", res.Body) 1388 } 1389 defer rwc.Close() 1390 1391 if got, want := res.Header.Get("X-Modified"), "true"; got != want { 1392 t.Errorf("response X-Modified header = %q; want %q", got, want) 1393 } 1394 1395 io.WriteString(rwc, "Hello\n") 1396 bs := bufio.NewScanner(rwc) 1397 if !bs.Scan() { 1398 t.Fatalf("Scan: %v", bs.Err()) 1399 } 1400 got = bs.Text() 1401 want = `backend got "Hello"` 1402 if got != want { 1403 t.Errorf("got %#q, want %#q", got, want) 1404 } 1405 } 1406 1407 func TestReverseProxyWebSocketCancellation(t *testing.T) { 1408 n := 5 1409 triggerCancelCh := make(chan bool, n) 1410 nthResponse := func(i int) string { 1411 return fmt.Sprintf("backend response #%d\n", i) 1412 } 1413 terminalMsg := "final message" 1414 1415 cst := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 1416 if g, ws := upgradeType(r.Header), "websocket"; g != ws { 1417 t.Errorf("Unexpected upgrade type %q, want %q", g, ws) 1418 http.Error(w, "Unexpected request", 400) 1419 return 1420 } 1421 conn, bufrw, err := w.(http.Hijacker).Hijack() 1422 if err != nil { 1423 t.Error(err) 1424 return 1425 } 1426 defer conn.Close() 1427 1428 upgradeMsg := "HTTP/1.1 101 Switching Protocols\r\nConnection: upgrade\r\nUpgrade: WebSocket\r\n\r\n" 1429 if _, err := io.WriteString(conn, upgradeMsg); err != nil { 1430 t.Error(err) 1431 return 1432 } 1433 if _, _, err := bufrw.ReadLine(); err != nil { 1434 t.Errorf("Failed to read line from client: %v", err) 1435 return 1436 } 1437 1438 for i := 0; i < n; i++ { 1439 if _, err := bufrw.WriteString(nthResponse(i)); err != nil { 1440 select { 1441 case <-triggerCancelCh: 1442 default: 1443 t.Errorf("Writing response #%d failed: %v", i, err) 1444 } 1445 return 1446 } 1447 bufrw.Flush() 1448 time.Sleep(time.Second) 1449 } 1450 if _, err := bufrw.WriteString(terminalMsg); err != nil { 1451 select { 1452 case <-triggerCancelCh: 1453 default: 1454 t.Errorf("Failed to write terminal message: %v", err) 1455 } 1456 } 1457 bufrw.Flush() 1458 })) 1459 defer cst.Close() 1460 1461 backendURL, _ := url.Parse(cst.URL) 1462 rproxy := NewSingleHostReverseProxy(backendURL) 1463 rproxy.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests 1464 rproxy.ModifyResponse = func(res *http.Response) error { 1465 res.Header.Add("X-Modified", "true") 1466 return nil 1467 } 1468 1469 handler := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { 1470 rw.Header().Set("X-Header", "X-Value") 1471 ctx, cancel := context.WithCancel(req.Context()) 1472 go func() { 1473 <-triggerCancelCh 1474 cancel() 1475 }() 1476 rproxy.ServeHTTP(rw, req.WithContext(ctx)) 1477 }) 1478 1479 frontendProxy := httptest.NewServer(handler) 1480 defer frontendProxy.Close() 1481 1482 req, _ := http.NewRequest("GET", frontendProxy.URL, nil) 1483 req.Header.Set("Connection", "Upgrade") 1484 req.Header.Set("Upgrade", "websocket") 1485 1486 res, err := frontendProxy.Client().Do(req) 1487 if err != nil { 1488 t.Fatalf("Dialing to frontend proxy: %v", err) 1489 } 1490 defer res.Body.Close() 1491 if g, w := res.StatusCode, 101; g != w { 1492 t.Fatalf("Switching protocols failed, got: %d, want: %d", g, w) 1493 } 1494 1495 if g, w := res.Header.Get("X-Header"), "X-Value"; g != w { 1496 t.Errorf("X-Header mismatch\n\tgot: %q\n\twant: %q", g, w) 1497 } 1498 1499 if g, w := upgradeType(res.Header), "websocket"; !ascii.EqualFold(g, w) { 1500 t.Fatalf("Upgrade header mismatch\n\tgot: %q\n\twant: %q", g, w) 1501 } 1502 1503 rwc, ok := res.Body.(io.ReadWriteCloser) 1504 if !ok { 1505 t.Fatalf("Response body type mismatch, got %T, want io.ReadWriteCloser", res.Body) 1506 } 1507 1508 if got, want := res.Header.Get("X-Modified"), "true"; got != want { 1509 t.Errorf("response X-Modified header = %q; want %q", got, want) 1510 } 1511 1512 if _, err := io.WriteString(rwc, "Hello\n"); err != nil { 1513 t.Fatalf("Failed to write first message: %v", err) 1514 } 1515 1516 // Read loop. 1517 1518 br := bufio.NewReader(rwc) 1519 for { 1520 line, err := br.ReadString('\n') 1521 switch { 1522 case line == terminalMsg: // this case before "err == io.EOF" 1523 t.Fatalf("The websocket request was not canceled, unfortunately!") 1524 1525 case err == io.EOF: 1526 return 1527 1528 case err != nil: 1529 t.Fatalf("Unexpected error: %v", err) 1530 1531 case line == nthResponse(0): // We've gotten the first response back 1532 // Let's trigger a cancel. 1533 close(triggerCancelCh) 1534 } 1535 } 1536 } 1537 1538 func TestUnannouncedTrailer(t *testing.T) { 1539 backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 1540 w.WriteHeader(http.StatusOK) 1541 w.(http.Flusher).Flush() 1542 w.Header().Set(http.TrailerPrefix+"X-Unannounced-Trailer", "unannounced_trailer_value") 1543 })) 1544 defer backend.Close() 1545 backendURL, err := url.Parse(backend.URL) 1546 if err != nil { 1547 t.Fatal(err) 1548 } 1549 proxyHandler := NewSingleHostReverseProxy(backendURL) 1550 proxyHandler.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests 1551 frontend := httptest.NewServer(proxyHandler) 1552 defer frontend.Close() 1553 frontendClient := frontend.Client() 1554 1555 res, err := frontendClient.Get(frontend.URL) 1556 if err != nil { 1557 t.Fatalf("Get: %v", err) 1558 } 1559 1560 io.ReadAll(res.Body) 1561 1562 if g, w := res.Trailer.Get("X-Unannounced-Trailer"), "unannounced_trailer_value"; g != w { 1563 t.Errorf("Trailer(X-Unannounced-Trailer) = %q; want %q", g, w) 1564 } 1565 1566 } 1567 1568 func TestSetURL(t *testing.T) { 1569 backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 1570 w.Write([]byte(r.Host)) 1571 })) 1572 defer backend.Close() 1573 backendURL, err := url.Parse(backend.URL) 1574 if err != nil { 1575 t.Fatal(err) 1576 } 1577 proxyHandler := &ReverseProxy{ 1578 Rewrite: func(r *ProxyRequest) { 1579 r.SetURL(backendURL) 1580 }, 1581 } 1582 frontend := httptest.NewServer(proxyHandler) 1583 defer frontend.Close() 1584 frontendClient := frontend.Client() 1585 1586 res, err := frontendClient.Get(frontend.URL) 1587 if err != nil { 1588 t.Fatalf("Get: %v", err) 1589 } 1590 defer res.Body.Close() 1591 1592 body, err := io.ReadAll(res.Body) 1593 if err != nil { 1594 t.Fatalf("Reading body: %v", err) 1595 } 1596 1597 if got, want := string(body), backendURL.Host; got != want { 1598 t.Errorf("backend got Host %q, want %q", got, want) 1599 } 1600 } 1601 1602 func TestSingleJoinSlash(t *testing.T) { 1603 tests := []struct { 1604 slasha string 1605 slashb string 1606 expected string 1607 }{ 1608 {"https://www.google.com/", "/favicon.ico", "https://www.google.com/favicon.ico"}, 1609 {"https://www.google.com", "/favicon.ico", "https://www.google.com/favicon.ico"}, 1610 {"https://www.google.com", "favicon.ico", "https://www.google.com/favicon.ico"}, 1611 {"https://www.google.com", "", "https://www.google.com/"}, 1612 {"", "favicon.ico", "/favicon.ico"}, 1613 } 1614 for _, tt := range tests { 1615 if got := singleJoiningSlash(tt.slasha, tt.slashb); got != tt.expected { 1616 t.Errorf("singleJoiningSlash(%q,%q) want %q got %q", 1617 tt.slasha, 1618 tt.slashb, 1619 tt.expected, 1620 got) 1621 } 1622 } 1623 } 1624 1625 func TestJoinURLPath(t *testing.T) { 1626 tests := []struct { 1627 a *url.URL 1628 b *url.URL 1629 wantPath string 1630 wantRaw string 1631 }{ 1632 {&url.URL{Path: "/a/b"}, &url.URL{Path: "/c"}, "/a/b/c", ""}, 1633 {&url.URL{Path: "/a/b", RawPath: "badpath"}, &url.URL{Path: "c"}, "/a/b/c", "/a/b/c"}, 1634 {&url.URL{Path: "/a/b", RawPath: "/a%2Fb"}, &url.URL{Path: "/c"}, "/a/b/c", "/a%2Fb/c"}, 1635 {&url.URL{Path: "/a/b", RawPath: "/a%2Fb"}, &url.URL{Path: "/c"}, "/a/b/c", "/a%2Fb/c"}, 1636 {&url.URL{Path: "/a/b/", RawPath: "/a%2Fb%2F"}, &url.URL{Path: "c"}, "/a/b//c", "/a%2Fb%2F/c"}, 1637 {&url.URL{Path: "/a/b/", RawPath: "/a%2Fb/"}, &url.URL{Path: "/c/d", RawPath: "/c%2Fd"}, "/a/b/c/d", "/a%2Fb/c%2Fd"}, 1638 } 1639 1640 for _, tt := range tests { 1641 p, rp := joinURLPath(tt.a, tt.b) 1642 if p != tt.wantPath || rp != tt.wantRaw { 1643 t.Errorf("joinURLPath(URL(%q,%q),URL(%q,%q)) want (%q,%q) got (%q,%q)", 1644 tt.a.Path, tt.a.RawPath, 1645 tt.b.Path, tt.b.RawPath, 1646 tt.wantPath, tt.wantRaw, 1647 p, rp) 1648 } 1649 } 1650 } 1651 1652 func TestReverseProxyRewriteReplacesOut(t *testing.T) { 1653 const content = "response_content" 1654 backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 1655 w.Write([]byte(content)) 1656 })) 1657 defer backend.Close() 1658 proxyHandler := &ReverseProxy{ 1659 Rewrite: func(r *ProxyRequest) { 1660 r.Out, _ = http.NewRequest("GET", backend.URL, nil) 1661 }, 1662 } 1663 frontend := httptest.NewServer(proxyHandler) 1664 defer frontend.Close() 1665 1666 res, err := frontend.Client().Get(frontend.URL) 1667 if err != nil { 1668 t.Fatalf("Get: %v", err) 1669 } 1670 defer res.Body.Close() 1671 body, _ := io.ReadAll(res.Body) 1672 if got, want := string(body), content; got != want { 1673 t.Errorf("got response %q, want %q", got, want) 1674 } 1675 } 1676 1677 func Test1xxResponses(t *testing.T) { 1678 backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 1679 h := w.Header() 1680 h.Add("Link", "</style.css>; rel=preload; as=style") 1681 h.Add("Link", "</script.js>; rel=preload; as=script") 1682 w.WriteHeader(http.StatusEarlyHints) 1683 1684 h.Add("Link", "</foo.js>; rel=preload; as=script") 1685 w.WriteHeader(http.StatusProcessing) 1686 1687 w.Write([]byte("Hello")) 1688 })) 1689 defer backend.Close() 1690 backendURL, err := url.Parse(backend.URL) 1691 if err != nil { 1692 t.Fatal(err) 1693 } 1694 proxyHandler := NewSingleHostReverseProxy(backendURL) 1695 proxyHandler.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests 1696 frontend := httptest.NewServer(proxyHandler) 1697 defer frontend.Close() 1698 frontendClient := frontend.Client() 1699 1700 checkLinkHeaders := func(t *testing.T, expected, got []string) { 1701 t.Helper() 1702 1703 if len(expected) != len(got) { 1704 t.Errorf("Expected %d link headers; got %d", len(expected), len(got)) 1705 } 1706 1707 for i := range expected { 1708 if i >= len(got) { 1709 t.Errorf("Expected %q link header; got nothing", expected[i]) 1710 1711 continue 1712 } 1713 1714 if expected[i] != got[i] { 1715 t.Errorf("Expected %q link header; got %q", expected[i], got[i]) 1716 } 1717 } 1718 } 1719 1720 var respCounter uint8 1721 trace := &httptrace.ClientTrace{ 1722 Got1xxResponse: func(code int, header textproto.MIMEHeader) error { 1723 switch code { 1724 case http.StatusEarlyHints: 1725 checkLinkHeaders(t, []string{"</style.css>; rel=preload; as=style", "</script.js>; rel=preload; as=script"}, header["Link"]) 1726 case http.StatusProcessing: 1727 checkLinkHeaders(t, []string{"</style.css>; rel=preload; as=style", "</script.js>; rel=preload; as=script", "</foo.js>; rel=preload; as=script"}, header["Link"]) 1728 default: 1729 t.Error("Unexpected 1xx response") 1730 } 1731 1732 respCounter++ 1733 1734 return nil 1735 }, 1736 } 1737 req, _ := http.NewRequestWithContext(httptrace.WithClientTrace(context.Background(), trace), "GET", frontend.URL, nil) 1738 1739 res, err := frontendClient.Do(req) 1740 if err != nil { 1741 t.Fatalf("Get: %v", err) 1742 } 1743 1744 defer res.Body.Close() 1745 1746 if respCounter != 2 { 1747 t.Errorf("Expected 2 1xx responses; got %d", respCounter) 1748 } 1749 checkLinkHeaders(t, []string{"</style.css>; rel=preload; as=style", "</script.js>; rel=preload; as=script", "</foo.js>; rel=preload; as=script"}, res.Header["Link"]) 1750 1751 body, _ := io.ReadAll(res.Body) 1752 if string(body) != "Hello" { 1753 t.Errorf("Read body %q; want Hello", body) 1754 } 1755 } 1756 1757 const ( 1758 testWantsCleanQuery = true 1759 testWantsRawQuery = false 1760 ) 1761 1762 func TestReverseProxyQueryParameterSmugglingDirectorDoesNotParseForm(t *testing.T) { 1763 testReverseProxyQueryParameterSmuggling(t, testWantsRawQuery, func(u *url.URL) *ReverseProxy { 1764 proxyHandler := NewSingleHostReverseProxy(u) 1765 oldDirector := proxyHandler.Director 1766 proxyHandler.Director = func(r *http.Request) { 1767 oldDirector(r) 1768 } 1769 return proxyHandler 1770 }) 1771 } 1772 1773 func TestReverseProxyQueryParameterSmugglingDirectorParsesForm(t *testing.T) { 1774 testReverseProxyQueryParameterSmuggling(t, testWantsCleanQuery, func(u *url.URL) *ReverseProxy { 1775 proxyHandler := NewSingleHostReverseProxy(u) 1776 oldDirector := proxyHandler.Director 1777 proxyHandler.Director = func(r *http.Request) { 1778 // Parsing the form causes ReverseProxy to remove unparsable 1779 // query parameters before forwarding. 1780 r.FormValue("a") 1781 oldDirector(r) 1782 } 1783 return proxyHandler 1784 }) 1785 } 1786 1787 func TestReverseProxyQueryParameterSmugglingRewrite(t *testing.T) { 1788 testReverseProxyQueryParameterSmuggling(t, testWantsCleanQuery, func(u *url.URL) *ReverseProxy { 1789 return &ReverseProxy{ 1790 Rewrite: func(r *ProxyRequest) { 1791 r.SetURL(u) 1792 }, 1793 } 1794 }) 1795 } 1796 1797 func TestReverseProxyQueryParameterSmugglingRewritePreservesRawQuery(t *testing.T) { 1798 testReverseProxyQueryParameterSmuggling(t, testWantsRawQuery, func(u *url.URL) *ReverseProxy { 1799 return &ReverseProxy{ 1800 Rewrite: func(r *ProxyRequest) { 1801 r.SetURL(u) 1802 r.Out.URL.RawQuery = r.In.URL.RawQuery 1803 }, 1804 } 1805 }) 1806 } 1807 1808 func testReverseProxyQueryParameterSmuggling(t *testing.T, wantCleanQuery bool, newProxy func(*url.URL) *ReverseProxy) { 1809 const content = "response_content" 1810 backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 1811 w.Write([]byte(r.URL.RawQuery)) 1812 })) 1813 defer backend.Close() 1814 backendURL, err := url.Parse(backend.URL) 1815 if err != nil { 1816 t.Fatal(err) 1817 } 1818 proxyHandler := newProxy(backendURL) 1819 frontend := httptest.NewServer(proxyHandler) 1820 defer frontend.Close() 1821 1822 // Don't spam output with logs of queries containing semicolons. 1823 backend.Config.ErrorLog = log.New(io.Discard, "", 0) 1824 frontend.Config.ErrorLog = log.New(io.Discard, "", 0) 1825 1826 for _, test := range []struct { 1827 rawQuery string 1828 cleanQuery string 1829 }{{ 1830 rawQuery: "a=1&a=2;b=3", 1831 cleanQuery: "a=1", 1832 }, { 1833 rawQuery: "a=1&a=%zz&b=3", 1834 cleanQuery: "a=1&b=3", 1835 }} { 1836 res, err := frontend.Client().Get(frontend.URL + "?" + test.rawQuery) 1837 if err != nil { 1838 t.Fatalf("Get: %v", err) 1839 } 1840 defer res.Body.Close() 1841 body, _ := io.ReadAll(res.Body) 1842 wantQuery := test.rawQuery 1843 if wantCleanQuery { 1844 wantQuery = test.cleanQuery 1845 } 1846 if got, want := string(body), wantQuery; got != want { 1847 t.Errorf("proxy forwarded raw query %q as %q, want %q", test.rawQuery, got, want) 1848 } 1849 } 1850 }