github.com/hxx258456/ccgo@v0.0.5-0.20230213014102-48b35f46f66f/gmhttp/httputil/reverseproxy_test.go (about) 1 // Copyright 2011 The Go Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 // Reverse proxy tests. 6 7 package httputil 8 9 import ( 10 "bufio" 11 "bytes" 12 "context" 13 "errors" 14 "fmt" 15 "io" 16 "log" 17 "net/url" 18 "os" 19 "reflect" 20 "sort" 21 "strconv" 22 "strings" 23 "sync" 24 "testing" 25 "time" 26 27 http "github.com/hxx258456/ccgo/gmhttp" 28 "github.com/hxx258456/ccgo/gmhttp/httptest" 29 "github.com/hxx258456/ccgo/gmhttp/internal/ascii" 30 ) 31 32 const fakeHopHeader = "X-Fake-Hop-Header-For-Test" 33 34 func init() { 35 inOurTests = true 36 hopHeaders = append(hopHeaders, fakeHopHeader) 37 } 38 39 func TestReverseProxy(t *testing.T) { 40 const backendResponse = "I am the backend" 41 const backendStatus = 404 42 backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 43 if r.Method == "GET" && r.FormValue("mode") == "hangup" { 44 c, _, _ := w.(http.Hijacker).Hijack() 45 c.Close() 46 return 47 } 48 if len(r.TransferEncoding) > 0 { 49 t.Errorf("backend got unexpected TransferEncoding: %v", r.TransferEncoding) 50 } 51 if r.Header.Get("X-Forwarded-For") == "" { 52 t.Errorf("didn't get X-Forwarded-For header") 53 } 54 if c := r.Header.Get("Connection"); c != "" { 55 t.Errorf("handler got Connection header value %q", c) 56 } 57 if c := r.Header.Get("Te"); c != "trailers" { 58 t.Errorf("handler got Te header value %q; want 'trailers'", c) 59 } 60 if c := r.Header.Get("Upgrade"); c != "" { 61 t.Errorf("handler got Upgrade header value %q", c) 62 } 63 if c := r.Header.Get("Proxy-Connection"); c != "" { 64 t.Errorf("handler got Proxy-Connection header value %q", c) 65 } 66 if g, e := r.Host, "some-name"; g != e { 67 t.Errorf("backend got Host header %q, want %q", g, e) 68 } 69 w.Header().Set("Trailers", "not a special header field name") 70 w.Header().Set("Trailer", "X-Trailer") 71 w.Header().Set("X-Foo", "bar") 72 w.Header().Set("Upgrade", "foo") 73 w.Header().Set(fakeHopHeader, "foo") 74 w.Header().Add("X-Multi-Value", "foo") 75 w.Header().Add("X-Multi-Value", "bar") 76 http.SetCookie(w, &http.Cookie{Name: "flavor", Value: "chocolateChip"}) 77 w.WriteHeader(backendStatus) 78 w.Write([]byte(backendResponse)) 79 w.Header().Set("X-Trailer", "trailer_value") 80 w.Header().Set(http.TrailerPrefix+"X-Unannounced-Trailer", "unannounced_trailer_value") 81 })) 82 defer backend.Close() 83 backendURL, err := url.Parse(backend.URL) 84 if err != nil { 85 t.Fatal(err) 86 } 87 proxyHandler := NewSingleHostReverseProxy(backendURL) 88 proxyHandler.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests 89 frontend := httptest.NewServer(proxyHandler) 90 defer frontend.Close() 91 frontendClient := frontend.Client() 92 93 getReq, _ := http.NewRequest("GET", frontend.URL, nil) 94 getReq.Host = "some-name" 95 getReq.Header.Set("Connection", "close, TE") 96 getReq.Header.Add("Te", "foo") 97 getReq.Header.Add("Te", "bar, trailers") 98 getReq.Header.Set("Proxy-Connection", "should be deleted") 99 getReq.Header.Set("Upgrade", "foo") 100 getReq.Close = true 101 res, err := frontendClient.Do(getReq) 102 if err != nil { 103 t.Fatalf("Get: %v", err) 104 } 105 if g, e := res.StatusCode, backendStatus; g != e { 106 t.Errorf("got res.StatusCode %d; expected %d", g, e) 107 } 108 if g, e := res.Header.Get("X-Foo"), "bar"; g != e { 109 t.Errorf("got X-Foo %q; expected %q", g, e) 110 } 111 if c := res.Header.Get(fakeHopHeader); c != "" { 112 t.Errorf("got %s header value %q", fakeHopHeader, c) 113 } 114 if g, e := res.Header.Get("Trailers"), "not a special header field name"; g != e { 115 t.Errorf("header Trailers = %q; want %q", g, e) 116 } 117 if g, e := len(res.Header["X-Multi-Value"]), 2; g != e { 118 t.Errorf("got %d X-Multi-Value header values; expected %d", g, e) 119 } 120 if g, e := len(res.Header["Set-Cookie"]), 1; g != e { 121 t.Fatalf("got %d SetCookies, want %d", g, e) 122 } 123 if g, e := res.Trailer, (http.Header{"X-Trailer": nil}); !reflect.DeepEqual(g, e) { 124 t.Errorf("before reading body, Trailer = %#v; want %#v", g, e) 125 } 126 if cookie := res.Cookies()[0]; cookie.Name != "flavor" { 127 t.Errorf("unexpected cookie %q", cookie.Name) 128 } 129 bodyBytes, _ := io.ReadAll(res.Body) 130 if g, e := string(bodyBytes), backendResponse; g != e { 131 t.Errorf("got body %q; expected %q", g, e) 132 } 133 if g, e := res.Trailer.Get("X-Trailer"), "trailer_value"; g != e { 134 t.Errorf("Trailer(X-Trailer) = %q ; want %q", g, e) 135 } 136 if g, e := res.Trailer.Get("X-Unannounced-Trailer"), "unannounced_trailer_value"; g != e { 137 t.Errorf("Trailer(X-Unannounced-Trailer) = %q ; want %q", g, e) 138 } 139 140 // Test that a backend failing to be reached or one which doesn't return 141 // a response results in a StatusBadGateway. 142 getReq, _ = http.NewRequest("GET", frontend.URL+"/?mode=hangup", nil) 143 getReq.Close = true 144 res, err = frontendClient.Do(getReq) 145 if err != nil { 146 t.Fatal(err) 147 } 148 res.Body.Close() 149 if res.StatusCode != http.StatusBadGateway { 150 t.Errorf("request to bad proxy = %v; want 502 StatusBadGateway", res.Status) 151 } 152 153 } 154 155 // Issue 16875: remove any proxied headers mentioned in the "Connection" 156 // header value. 157 func TestReverseProxyStripHeadersPresentInConnection(t *testing.T) { 158 const fakeConnectionToken = "X-Fake-Connection-Token" 159 const backendResponse = "I am the backend" 160 161 // someConnHeader is some arbitrary header to be declared as a hop-by-hop header 162 // in the Request's Connection header. 163 const someConnHeader = "X-Some-Conn-Header" 164 165 backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 166 if c := r.Header.Get("Connection"); c != "" { 167 t.Errorf("handler got header %q = %q; want empty", "Connection", c) 168 } 169 if c := r.Header.Get(fakeConnectionToken); c != "" { 170 t.Errorf("handler got header %q = %q; want empty", fakeConnectionToken, c) 171 } 172 if c := r.Header.Get(someConnHeader); c != "" { 173 t.Errorf("handler got header %q = %q; want empty", someConnHeader, c) 174 } 175 w.Header().Add("Connection", "Upgrade, "+fakeConnectionToken) 176 w.Header().Add("Connection", someConnHeader) 177 w.Header().Set(someConnHeader, "should be deleted") 178 w.Header().Set(fakeConnectionToken, "should be deleted") 179 io.WriteString(w, backendResponse) 180 })) 181 defer backend.Close() 182 backendURL, err := url.Parse(backend.URL) 183 if err != nil { 184 t.Fatal(err) 185 } 186 proxyHandler := NewSingleHostReverseProxy(backendURL) 187 frontend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 188 proxyHandler.ServeHTTP(w, r) 189 if c := r.Header.Get(someConnHeader); c != "should be deleted" { 190 t.Errorf("handler modified header %q = %q; want %q", someConnHeader, c, "should be deleted") 191 } 192 if c := r.Header.Get(fakeConnectionToken); c != "should be deleted" { 193 t.Errorf("handler modified header %q = %q; want %q", fakeConnectionToken, c, "should be deleted") 194 } 195 c := r.Header["Connection"] 196 var cf []string 197 for _, f := range c { 198 for _, sf := range strings.Split(f, ",") { 199 if sf = strings.TrimSpace(sf); sf != "" { 200 cf = append(cf, sf) 201 } 202 } 203 } 204 sort.Strings(cf) 205 expectedValues := []string{"Upgrade", someConnHeader, fakeConnectionToken} 206 sort.Strings(expectedValues) 207 if !reflect.DeepEqual(cf, expectedValues) { 208 t.Errorf("handler modified header %q = %q; want %q", "Connection", cf, expectedValues) 209 } 210 })) 211 defer frontend.Close() 212 213 getReq, _ := http.NewRequest("GET", frontend.URL, nil) 214 getReq.Header.Add("Connection", "Upgrade, "+fakeConnectionToken) 215 getReq.Header.Add("Connection", someConnHeader) 216 getReq.Header.Set(someConnHeader, "should be deleted") 217 getReq.Header.Set(fakeConnectionToken, "should be deleted") 218 res, err := frontend.Client().Do(getReq) 219 if err != nil { 220 t.Fatalf("Get: %v", err) 221 } 222 defer res.Body.Close() 223 bodyBytes, err := io.ReadAll(res.Body) 224 if err != nil { 225 t.Fatalf("reading body: %v", err) 226 } 227 if got, want := string(bodyBytes), backendResponse; got != want { 228 t.Errorf("got body %q; want %q", got, want) 229 } 230 if c := res.Header.Get("Connection"); c != "" { 231 t.Errorf("handler got header %q = %q; want empty", "Connection", c) 232 } 233 if c := res.Header.Get(someConnHeader); c != "" { 234 t.Errorf("handler got header %q = %q; want empty", someConnHeader, c) 235 } 236 if c := res.Header.Get(fakeConnectionToken); c != "" { 237 t.Errorf("handler got header %q = %q; want empty", fakeConnectionToken, c) 238 } 239 } 240 241 func TestReverseProxyStripEmptyConnection(t *testing.T) { 242 // See Issue 46313. 243 const backendResponse = "I am the backend" 244 245 // someConnHeader is some arbitrary header to be declared as a hop-by-hop header 246 // in the Request's Connection header. 247 const someConnHeader = "X-Some-Conn-Header" 248 249 backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 250 if c := r.Header.Values("Connection"); len(c) != 0 { 251 t.Errorf("handler got header %q = %v; want empty", "Connection", c) 252 } 253 if c := r.Header.Get(someConnHeader); c != "" { 254 t.Errorf("handler got header %q = %q; want empty", someConnHeader, c) 255 } 256 w.Header().Add("Connection", "") 257 w.Header().Add("Connection", someConnHeader) 258 w.Header().Set(someConnHeader, "should be deleted") 259 io.WriteString(w, backendResponse) 260 })) 261 defer backend.Close() 262 backendURL, err := url.Parse(backend.URL) 263 if err != nil { 264 t.Fatal(err) 265 } 266 proxyHandler := NewSingleHostReverseProxy(backendURL) 267 frontend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 268 proxyHandler.ServeHTTP(w, r) 269 if c := r.Header.Get(someConnHeader); c != "should be deleted" { 270 t.Errorf("handler modified header %q = %q; want %q", someConnHeader, c, "should be deleted") 271 } 272 })) 273 defer frontend.Close() 274 275 getReq, _ := http.NewRequest("GET", frontend.URL, nil) 276 getReq.Header.Add("Connection", "") 277 getReq.Header.Add("Connection", someConnHeader) 278 getReq.Header.Set(someConnHeader, "should be deleted") 279 res, err := frontend.Client().Do(getReq) 280 if err != nil { 281 t.Fatalf("Get: %v", err) 282 } 283 defer res.Body.Close() 284 bodyBytes, err := io.ReadAll(res.Body) 285 if err != nil { 286 t.Fatalf("reading body: %v", err) 287 } 288 if got, want := string(bodyBytes), backendResponse; got != want { 289 t.Errorf("got body %q; want %q", got, want) 290 } 291 if c := res.Header.Get("Connection"); c != "" { 292 t.Errorf("handler got header %q = %q; want empty", "Connection", c) 293 } 294 if c := res.Header.Get(someConnHeader); c != "" { 295 t.Errorf("handler got header %q = %q; want empty", someConnHeader, c) 296 } 297 } 298 299 func TestXForwardedFor(t *testing.T) { 300 const prevForwardedFor = "client ip" 301 const backendResponse = "I am the backend" 302 const backendStatus = 404 303 backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 304 if r.Header.Get("X-Forwarded-For") == "" { 305 t.Errorf("didn't get X-Forwarded-For header") 306 } 307 if !strings.Contains(r.Header.Get("X-Forwarded-For"), prevForwardedFor) { 308 t.Errorf("X-Forwarded-For didn't contain prior data") 309 } 310 w.WriteHeader(backendStatus) 311 w.Write([]byte(backendResponse)) 312 })) 313 defer backend.Close() 314 backendURL, err := url.Parse(backend.URL) 315 if err != nil { 316 t.Fatal(err) 317 } 318 proxyHandler := NewSingleHostReverseProxy(backendURL) 319 frontend := httptest.NewServer(proxyHandler) 320 defer frontend.Close() 321 322 getReq, _ := http.NewRequest("GET", frontend.URL, nil) 323 getReq.Host = "some-name" 324 getReq.Header.Set("Connection", "close") 325 getReq.Header.Set("X-Forwarded-For", prevForwardedFor) 326 getReq.Close = true 327 res, err := frontend.Client().Do(getReq) 328 if err != nil { 329 t.Fatalf("Get: %v", err) 330 } 331 if g, e := res.StatusCode, backendStatus; g != e { 332 t.Errorf("got res.StatusCode %d; expected %d", g, e) 333 } 334 bodyBytes, _ := io.ReadAll(res.Body) 335 if g, e := string(bodyBytes), backendResponse; g != e { 336 t.Errorf("got body %q; expected %q", g, e) 337 } 338 } 339 340 // Issue 38079: don't append to X-Forwarded-For if it's present but nil 341 func TestXForwardedFor_Omit(t *testing.T) { 342 backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 343 if v := r.Header.Get("X-Forwarded-For"); v != "" { 344 t.Errorf("got X-Forwarded-For header: %q", v) 345 } 346 w.Write([]byte("hi")) 347 })) 348 defer backend.Close() 349 backendURL, err := url.Parse(backend.URL) 350 if err != nil { 351 t.Fatal(err) 352 } 353 proxyHandler := NewSingleHostReverseProxy(backendURL) 354 frontend := httptest.NewServer(proxyHandler) 355 defer frontend.Close() 356 357 oldDirector := proxyHandler.Director 358 proxyHandler.Director = func(r *http.Request) { 359 r.Header["X-Forwarded-For"] = nil 360 oldDirector(r) 361 } 362 363 getReq, _ := http.NewRequest("GET", frontend.URL, nil) 364 getReq.Host = "some-name" 365 getReq.Close = true 366 res, err := frontend.Client().Do(getReq) 367 if err != nil { 368 t.Fatalf("Get: %v", err) 369 } 370 res.Body.Close() 371 } 372 373 var proxyQueryTests = []struct { 374 baseSuffix string // suffix to add to backend URL 375 reqSuffix string // suffix to add to frontend's request URL 376 want string // what backend should see for final request URL (without ?) 377 }{ 378 {"", "", ""}, 379 {"?sta=tic", "?us=er", "sta=tic&us=er"}, 380 {"", "?us=er", "us=er"}, 381 {"?sta=tic", "", "sta=tic"}, 382 } 383 384 func TestReverseProxyQuery(t *testing.T) { 385 backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 386 w.Header().Set("X-Got-Query", r.URL.RawQuery) 387 w.Write([]byte("hi")) 388 })) 389 defer backend.Close() 390 391 for i, tt := range proxyQueryTests { 392 backendURL, err := url.Parse(backend.URL + tt.baseSuffix) 393 if err != nil { 394 t.Fatal(err) 395 } 396 frontend := httptest.NewServer(NewSingleHostReverseProxy(backendURL)) 397 req, _ := http.NewRequest("GET", frontend.URL+tt.reqSuffix, nil) 398 req.Close = true 399 res, err := frontend.Client().Do(req) 400 if err != nil { 401 t.Fatalf("%d. Get: %v", i, err) 402 } 403 if g, e := res.Header.Get("X-Got-Query"), tt.want; g != e { 404 t.Errorf("%d. got query %q; expected %q", i, g, e) 405 } 406 res.Body.Close() 407 frontend.Close() 408 } 409 } 410 411 func TestReverseProxyFlushInterval(t *testing.T) { 412 const expected = "hi" 413 backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 414 w.Write([]byte(expected)) 415 })) 416 defer backend.Close() 417 418 backendURL, err := url.Parse(backend.URL) 419 if err != nil { 420 t.Fatal(err) 421 } 422 423 proxyHandler := NewSingleHostReverseProxy(backendURL) 424 proxyHandler.FlushInterval = time.Microsecond 425 426 frontend := httptest.NewServer(proxyHandler) 427 defer frontend.Close() 428 429 req, _ := http.NewRequest("GET", frontend.URL, nil) 430 req.Close = true 431 res, err := frontend.Client().Do(req) 432 if err != nil { 433 t.Fatalf("Get: %v", err) 434 } 435 defer res.Body.Close() 436 if bodyBytes, _ := io.ReadAll(res.Body); string(bodyBytes) != expected { 437 t.Errorf("got body %q; expected %q", bodyBytes, expected) 438 } 439 } 440 441 func TestReverseProxyFlushIntervalHeaders(t *testing.T) { 442 const expected = "hi" 443 stopCh := make(chan struct{}) 444 backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 445 w.Header().Add("MyHeader", expected) 446 w.WriteHeader(200) 447 w.(http.Flusher).Flush() 448 <-stopCh 449 })) 450 defer backend.Close() 451 defer close(stopCh) 452 453 backendURL, err := url.Parse(backend.URL) 454 if err != nil { 455 t.Fatal(err) 456 } 457 458 proxyHandler := NewSingleHostReverseProxy(backendURL) 459 proxyHandler.FlushInterval = time.Microsecond 460 461 frontend := httptest.NewServer(proxyHandler) 462 defer frontend.Close() 463 464 req, _ := http.NewRequest("GET", frontend.URL, nil) 465 req.Close = true 466 467 ctx, cancel := context.WithTimeout(req.Context(), 10*time.Second) 468 defer cancel() 469 req = req.WithContext(ctx) 470 471 res, err := frontend.Client().Do(req) 472 if err != nil { 473 t.Fatalf("Get: %v", err) 474 } 475 defer res.Body.Close() 476 477 if res.Header.Get("MyHeader") != expected { 478 t.Errorf("got header %q; expected %q", res.Header.Get("MyHeader"), expected) 479 } 480 } 481 482 func TestReverseProxyCancellation(t *testing.T) { 483 const backendResponse = "I am the backend" 484 485 reqInFlight := make(chan struct{}) 486 backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 487 close(reqInFlight) // cause the client to cancel its request 488 489 select { 490 case <-time.After(10 * time.Second): 491 // Note: this should only happen in broken implementations, and the 492 // closenotify case should be instantaneous. 493 t.Error("Handler never saw CloseNotify") 494 return 495 case <-w.(http.CloseNotifier).CloseNotify(): 496 } 497 498 w.WriteHeader(http.StatusOK) 499 w.Write([]byte(backendResponse)) 500 })) 501 502 defer backend.Close() 503 504 backend.Config.ErrorLog = log.New(io.Discard, "", 0) 505 506 backendURL, err := url.Parse(backend.URL) 507 if err != nil { 508 t.Fatal(err) 509 } 510 511 proxyHandler := NewSingleHostReverseProxy(backendURL) 512 513 // Discards errors of the form: 514 // http: proxy error: read tcp 127.0.0.1:44643: use of closed network connection 515 proxyHandler.ErrorLog = log.New(io.Discard, "", 0) 516 517 frontend := httptest.NewServer(proxyHandler) 518 defer frontend.Close() 519 frontendClient := frontend.Client() 520 521 getReq, _ := http.NewRequest("GET", frontend.URL, nil) 522 go func() { 523 <-reqInFlight 524 frontendClient.Transport.(*http.Transport).CancelRequest(getReq) 525 }() 526 res, err := frontendClient.Do(getReq) 527 if res != nil { 528 t.Errorf("got response %v; want nil", res.Status) 529 } 530 if err == nil { 531 // This should be an error like: 532 // Get "http://127.0.0.1:58079": read tcp 127.0.0.1:58079: 533 // use of closed network connection 534 t.Error("Server.Client().Do() returned nil error; want non-nil error") 535 } 536 } 537 538 func req(t *testing.T, v string) *http.Request { 539 req, err := http.ReadRequest(bufio.NewReader(strings.NewReader(v))) 540 if err != nil { 541 t.Fatal(err) 542 } 543 return req 544 } 545 546 // Issue 12344 547 func TestNilBody(t *testing.T) { 548 backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 549 w.Write([]byte("hi")) 550 })) 551 defer backend.Close() 552 553 frontend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { 554 backURL, _ := url.Parse(backend.URL) 555 rp := NewSingleHostReverseProxy(backURL) 556 r := req(t, "GET / HTTP/1.0\r\n\r\n") 557 r.Body = nil // this accidentally worked in Go 1.4 and below, so keep it working 558 rp.ServeHTTP(w, r) 559 })) 560 defer frontend.Close() 561 562 res, err := http.Get(frontend.URL) 563 if err != nil { 564 t.Fatal(err) 565 } 566 defer res.Body.Close() 567 slurp, err := io.ReadAll(res.Body) 568 if err != nil { 569 t.Fatal(err) 570 } 571 if string(slurp) != "hi" { 572 t.Errorf("Got %q; want %q", slurp, "hi") 573 } 574 } 575 576 // Issue 15524 577 func TestUserAgentHeader(t *testing.T) { 578 const explicitUA = "explicit UA" 579 backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 580 if r.URL.Path == "/noua" { 581 if c := r.Header.Get("User-Agent"); c != "" { 582 t.Errorf("handler got non-empty User-Agent header %q", c) 583 } 584 return 585 } 586 if c := r.Header.Get("User-Agent"); c != explicitUA { 587 t.Errorf("handler got unexpected User-Agent header %q", c) 588 } 589 })) 590 defer backend.Close() 591 backendURL, err := url.Parse(backend.URL) 592 if err != nil { 593 t.Fatal(err) 594 } 595 proxyHandler := NewSingleHostReverseProxy(backendURL) 596 proxyHandler.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests 597 frontend := httptest.NewServer(proxyHandler) 598 defer frontend.Close() 599 frontendClient := frontend.Client() 600 601 getReq, _ := http.NewRequest("GET", frontend.URL, nil) 602 getReq.Header.Set("User-Agent", explicitUA) 603 getReq.Close = true 604 res, err := frontendClient.Do(getReq) 605 if err != nil { 606 t.Fatalf("Get: %v", err) 607 } 608 res.Body.Close() 609 610 getReq, _ = http.NewRequest("GET", frontend.URL+"/noua", nil) 611 getReq.Header.Set("User-Agent", "") 612 getReq.Close = true 613 res, err = frontendClient.Do(getReq) 614 if err != nil { 615 t.Fatalf("Get: %v", err) 616 } 617 res.Body.Close() 618 } 619 620 type bufferPool struct { 621 get func() []byte 622 put func([]byte) 623 } 624 625 func (bp bufferPool) Get() []byte { return bp.get() } 626 func (bp bufferPool) Put(v []byte) { bp.put(v) } 627 628 func TestReverseProxyGetPutBuffer(t *testing.T) { 629 const msg = "hi" 630 backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 631 io.WriteString(w, msg) 632 })) 633 defer backend.Close() 634 635 backendURL, err := url.Parse(backend.URL) 636 if err != nil { 637 t.Fatal(err) 638 } 639 640 var ( 641 mu sync.Mutex 642 log []string 643 ) 644 addLog := func(event string) { 645 mu.Lock() 646 defer mu.Unlock() 647 log = append(log, event) 648 } 649 rp := NewSingleHostReverseProxy(backendURL) 650 const size = 1234 651 rp.BufferPool = bufferPool{ 652 get: func() []byte { 653 addLog("getBuf") 654 return make([]byte, size) 655 }, 656 put: func(p []byte) { 657 addLog("putBuf-" + strconv.Itoa(len(p))) 658 }, 659 } 660 frontend := httptest.NewServer(rp) 661 defer frontend.Close() 662 663 req, _ := http.NewRequest("GET", frontend.URL, nil) 664 req.Close = true 665 res, err := frontend.Client().Do(req) 666 if err != nil { 667 t.Fatalf("Get: %v", err) 668 } 669 slurp, err := io.ReadAll(res.Body) 670 res.Body.Close() 671 if err != nil { 672 t.Fatalf("reading body: %v", err) 673 } 674 if string(slurp) != msg { 675 t.Errorf("msg = %q; want %q", slurp, msg) 676 } 677 wantLog := []string{"getBuf", "putBuf-" + strconv.Itoa(size)} 678 mu.Lock() 679 defer mu.Unlock() 680 if !reflect.DeepEqual(log, wantLog) { 681 t.Errorf("Log events = %q; want %q", log, wantLog) 682 } 683 } 684 685 func TestReverseProxy_Post(t *testing.T) { 686 const backendResponse = "I am the backend" 687 const backendStatus = 200 688 var requestBody = bytes.Repeat([]byte("a"), 1<<20) 689 backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 690 slurp, err := io.ReadAll(r.Body) 691 if err != nil { 692 t.Errorf("Backend body read = %v", err) 693 } 694 if len(slurp) != len(requestBody) { 695 t.Errorf("Backend read %d request body bytes; want %d", len(slurp), len(requestBody)) 696 } 697 if !bytes.Equal(slurp, requestBody) { 698 t.Error("Backend read wrong request body.") // 1MB; omitting details 699 } 700 w.Write([]byte(backendResponse)) 701 })) 702 defer backend.Close() 703 backendURL, err := url.Parse(backend.URL) 704 if err != nil { 705 t.Fatal(err) 706 } 707 proxyHandler := NewSingleHostReverseProxy(backendURL) 708 frontend := httptest.NewServer(proxyHandler) 709 defer frontend.Close() 710 711 postReq, _ := http.NewRequest("POST", frontend.URL, bytes.NewReader(requestBody)) 712 res, err := frontend.Client().Do(postReq) 713 if err != nil { 714 t.Fatalf("Do: %v", err) 715 } 716 if g, e := res.StatusCode, backendStatus; g != e { 717 t.Errorf("got res.StatusCode %d; expected %d", g, e) 718 } 719 bodyBytes, _ := io.ReadAll(res.Body) 720 if g, e := string(bodyBytes), backendResponse; g != e { 721 t.Errorf("got body %q; expected %q", g, e) 722 } 723 } 724 725 type RoundTripperFunc func(*http.Request) (*http.Response, error) 726 727 func (fn RoundTripperFunc) RoundTrip(req *http.Request) (*http.Response, error) { 728 return fn(req) 729 } 730 731 // Issue 16036: send a Request with a nil Body when possible 732 func TestReverseProxy_NilBody(t *testing.T) { 733 backendURL, _ := url.Parse("http://fake.tld/") 734 proxyHandler := NewSingleHostReverseProxy(backendURL) 735 proxyHandler.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests 736 proxyHandler.Transport = RoundTripperFunc(func(req *http.Request) (*http.Response, error) { 737 if req.Body != nil { 738 t.Error("Body != nil; want a nil Body") 739 } 740 return nil, errors.New("done testing the interesting part; so force a 502 Gateway error") 741 }) 742 frontend := httptest.NewServer(proxyHandler) 743 defer frontend.Close() 744 745 res, err := frontend.Client().Get(frontend.URL) 746 if err != nil { 747 t.Fatal(err) 748 } 749 defer res.Body.Close() 750 if res.StatusCode != 502 { 751 t.Errorf("status code = %v; want 502 (Gateway Error)", res.Status) 752 } 753 } 754 755 // Issue 33142: always allocate the request headers 756 func TestReverseProxy_AllocatedHeader(t *testing.T) { 757 proxyHandler := new(ReverseProxy) 758 proxyHandler.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests 759 proxyHandler.Director = func(*http.Request) {} // noop 760 proxyHandler.Transport = RoundTripperFunc(func(req *http.Request) (*http.Response, error) { 761 if req.Header == nil { 762 t.Error("Header == nil; want a non-nil Header") 763 } 764 return nil, errors.New("done testing the interesting part; so force a 502 Gateway error") 765 }) 766 767 proxyHandler.ServeHTTP(httptest.NewRecorder(), &http.Request{ 768 Method: "GET", 769 URL: &url.URL{Scheme: "http", Host: "fake.tld", Path: "/"}, 770 Proto: "HTTP/1.0", 771 ProtoMajor: 1, 772 }) 773 } 774 775 // Issue 14237. Test ModifyResponse and that an error from it 776 // causes the proxy to return StatusBadGateway, or StatusOK otherwise. 777 func TestReverseProxyModifyResponse(t *testing.T) { 778 backendServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 779 w.Header().Add("X-Hit-Mod", fmt.Sprintf("%v", r.URL.Path == "/mod")) 780 })) 781 defer backendServer.Close() 782 783 rpURL, _ := url.Parse(backendServer.URL) 784 rproxy := NewSingleHostReverseProxy(rpURL) 785 rproxy.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests 786 rproxy.ModifyResponse = func(resp *http.Response) error { 787 if resp.Header.Get("X-Hit-Mod") != "true" { 788 return fmt.Errorf("tried to by-pass proxy") 789 } 790 return nil 791 } 792 793 frontendProxy := httptest.NewServer(rproxy) 794 defer frontendProxy.Close() 795 796 tests := []struct { 797 url string 798 wantCode int 799 }{ 800 {frontendProxy.URL + "/mod", http.StatusOK}, 801 {frontendProxy.URL + "/schedule", http.StatusBadGateway}, 802 } 803 804 for i, tt := range tests { 805 resp, err := http.Get(tt.url) 806 if err != nil { 807 t.Fatalf("failed to reach proxy: %v", err) 808 } 809 if g, e := resp.StatusCode, tt.wantCode; g != e { 810 t.Errorf("#%d: got res.StatusCode %d; expected %d", i, g, e) 811 } 812 resp.Body.Close() 813 } 814 } 815 816 type failingRoundTripper struct{} 817 818 func (failingRoundTripper) RoundTrip(*http.Request) (*http.Response, error) { 819 return nil, errors.New("some error") 820 } 821 822 type staticResponseRoundTripper struct{ res *http.Response } 823 824 func (rt staticResponseRoundTripper) RoundTrip(*http.Request) (*http.Response, error) { 825 return rt.res, nil 826 } 827 828 func TestReverseProxyErrorHandler(t *testing.T) { 829 tests := []struct { 830 name string 831 wantCode int 832 errorHandler func(http.ResponseWriter, *http.Request, error) 833 transport http.RoundTripper // defaults to failingRoundTripper 834 modifyResponse func(*http.Response) error 835 }{ 836 { 837 name: "default", 838 wantCode: http.StatusBadGateway, 839 }, 840 { 841 name: "errorhandler", 842 wantCode: http.StatusTeapot, 843 errorHandler: func(rw http.ResponseWriter, req *http.Request, err error) { rw.WriteHeader(http.StatusTeapot) }, 844 }, 845 { 846 name: "modifyresponse_noerr", 847 transport: staticResponseRoundTripper{ 848 &http.Response{StatusCode: 345, Body: http.NoBody}, 849 }, 850 modifyResponse: func(res *http.Response) error { 851 res.StatusCode++ 852 return nil 853 }, 854 errorHandler: func(rw http.ResponseWriter, req *http.Request, err error) { rw.WriteHeader(http.StatusTeapot) }, 855 wantCode: 346, 856 }, 857 { 858 name: "modifyresponse_err", 859 transport: staticResponseRoundTripper{ 860 &http.Response{StatusCode: 345, Body: http.NoBody}, 861 }, 862 modifyResponse: func(res *http.Response) error { 863 res.StatusCode++ 864 return errors.New("some error to trigger errorHandler") 865 }, 866 errorHandler: func(rw http.ResponseWriter, req *http.Request, err error) { rw.WriteHeader(http.StatusTeapot) }, 867 wantCode: http.StatusTeapot, 868 }, 869 } 870 871 for _, tt := range tests { 872 t.Run(tt.name, func(t *testing.T) { 873 target := &url.URL{ 874 Scheme: "http", 875 Host: "dummy.tld", 876 Path: "/", 877 } 878 rproxy := NewSingleHostReverseProxy(target) 879 rproxy.Transport = tt.transport 880 rproxy.ModifyResponse = tt.modifyResponse 881 if rproxy.Transport == nil { 882 rproxy.Transport = failingRoundTripper{} 883 } 884 rproxy.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests 885 if tt.errorHandler != nil { 886 rproxy.ErrorHandler = tt.errorHandler 887 } 888 frontendProxy := httptest.NewServer(rproxy) 889 defer frontendProxy.Close() 890 891 resp, err := http.Get(frontendProxy.URL + "/test") 892 if err != nil { 893 t.Fatalf("failed to reach proxy: %v", err) 894 } 895 if g, e := resp.StatusCode, tt.wantCode; g != e { 896 t.Errorf("got res.StatusCode %d; expected %d", g, e) 897 } 898 resp.Body.Close() 899 }) 900 } 901 } 902 903 // Issue 16659: log errors from short read 904 func TestReverseProxy_CopyBuffer(t *testing.T) { 905 backendServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 906 out := "this call was relayed by the reverse proxy" 907 // Coerce a wrong content length to induce io.UnexpectedEOF 908 w.Header().Set("Content-Length", fmt.Sprintf("%d", len(out)*2)) 909 fmt.Fprintln(w, out) 910 })) 911 defer backendServer.Close() 912 913 rpURL, err := url.Parse(backendServer.URL) 914 if err != nil { 915 t.Fatal(err) 916 } 917 918 var proxyLog bytes.Buffer 919 rproxy := NewSingleHostReverseProxy(rpURL) 920 rproxy.ErrorLog = log.New(&proxyLog, "", log.Lshortfile) 921 donec := make(chan bool, 1) 922 frontendProxy := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 923 defer func() { donec <- true }() 924 rproxy.ServeHTTP(w, r) 925 })) 926 defer frontendProxy.Close() 927 928 if _, err = frontendProxy.Client().Get(frontendProxy.URL); err == nil { 929 t.Fatalf("want non-nil error") 930 } 931 // The race detector complains about the proxyLog usage in logf in copyBuffer 932 // and our usage below with proxyLog.Bytes() so we're explicitly using a 933 // channel to ensure that the ReverseProxy's ServeHTTP is done before we 934 // continue after Get. 935 <-donec 936 937 expected := []string{ 938 "EOF", 939 "read", 940 } 941 for _, phrase := range expected { 942 if !bytes.Contains(proxyLog.Bytes(), []byte(phrase)) { 943 t.Errorf("expected log to contain phrase %q", phrase) 944 } 945 } 946 } 947 948 type staticTransport struct { 949 res *http.Response 950 } 951 952 func (t *staticTransport) RoundTrip(r *http.Request) (*http.Response, error) { 953 return t.res, nil 954 } 955 956 func BenchmarkServeHTTP(b *testing.B) { 957 res := &http.Response{ 958 StatusCode: 200, 959 Body: io.NopCloser(strings.NewReader("")), 960 } 961 proxy := &ReverseProxy{ 962 Director: func(*http.Request) {}, 963 Transport: &staticTransport{res}, 964 } 965 966 w := httptest.NewRecorder() 967 r := httptest.NewRequest("GET", "/", nil) 968 969 b.ReportAllocs() 970 for i := 0; i < b.N; i++ { 971 proxy.ServeHTTP(w, r) 972 } 973 } 974 975 func TestServeHTTPDeepCopy(t *testing.T) { 976 backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 977 w.Write([]byte("Hello Gopher!")) 978 })) 979 defer backend.Close() 980 backendURL, err := url.Parse(backend.URL) 981 if err != nil { 982 t.Fatal(err) 983 } 984 985 type result struct { 986 before, after string 987 } 988 989 resultChan := make(chan result, 1) 990 proxyHandler := NewSingleHostReverseProxy(backendURL) 991 frontend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 992 before := r.URL.String() 993 proxyHandler.ServeHTTP(w, r) 994 after := r.URL.String() 995 resultChan <- result{before: before, after: after} 996 })) 997 defer frontend.Close() 998 999 want := result{before: "/", after: "/"} 1000 1001 res, err := frontend.Client().Get(frontend.URL) 1002 if err != nil { 1003 t.Fatalf("Do: %v", err) 1004 } 1005 res.Body.Close() 1006 1007 got := <-resultChan 1008 if got != want { 1009 t.Errorf("got = %+v; want = %+v", got, want) 1010 } 1011 } 1012 1013 // Issue 18327: verify we always do a deep copy of the Request.Header map 1014 // before any mutations. 1015 func TestClonesRequestHeaders(t *testing.T) { 1016 log.SetOutput(io.Discard) 1017 defer log.SetOutput(os.Stderr) 1018 req, _ := http.NewRequest("GET", "http://foo.tld/", nil) 1019 req.RemoteAddr = "1.2.3.4:56789" 1020 rp := &ReverseProxy{ 1021 Director: func(req *http.Request) { 1022 req.Header.Set("From-Director", "1") 1023 }, 1024 Transport: roundTripperFunc(func(req *http.Request) (*http.Response, error) { 1025 if v := req.Header.Get("From-Director"); v != "1" { 1026 t.Errorf("From-Directory value = %q; want 1", v) 1027 } 1028 return nil, io.EOF 1029 }), 1030 } 1031 rp.ServeHTTP(httptest.NewRecorder(), req) 1032 1033 if req.Header.Get("From-Director") == "1" { 1034 t.Error("Director header mutation modified caller's request") 1035 } 1036 if req.Header.Get("X-Forwarded-For") != "" { 1037 t.Error("X-Forward-For header mutation modified caller's request") 1038 } 1039 1040 } 1041 1042 type roundTripperFunc func(req *http.Request) (*http.Response, error) 1043 1044 func (fn roundTripperFunc) RoundTrip(req *http.Request) (*http.Response, error) { 1045 return fn(req) 1046 } 1047 1048 func TestModifyResponseClosesBody(t *testing.T) { 1049 req, _ := http.NewRequest("GET", "http://foo.tld/", nil) 1050 req.RemoteAddr = "1.2.3.4:56789" 1051 closeCheck := new(checkCloser) 1052 logBuf := new(bytes.Buffer) 1053 outErr := errors.New("ModifyResponse error") 1054 rp := &ReverseProxy{ 1055 Director: func(req *http.Request) {}, 1056 Transport: &staticTransport{&http.Response{ 1057 StatusCode: 200, 1058 Body: closeCheck, 1059 }}, 1060 ErrorLog: log.New(logBuf, "", 0), 1061 ModifyResponse: func(*http.Response) error { 1062 return outErr 1063 }, 1064 } 1065 rec := httptest.NewRecorder() 1066 rp.ServeHTTP(rec, req) 1067 res := rec.Result() 1068 if g, e := res.StatusCode, http.StatusBadGateway; g != e { 1069 t.Errorf("got res.StatusCode %d; expected %d", g, e) 1070 } 1071 if !closeCheck.closed { 1072 t.Errorf("body should have been closed") 1073 } 1074 if g, e := logBuf.String(), outErr.Error(); !strings.Contains(g, e) { 1075 t.Errorf("ErrorLog %q does not contain %q", g, e) 1076 } 1077 } 1078 1079 type checkCloser struct { 1080 closed bool 1081 } 1082 1083 func (cc *checkCloser) Close() error { 1084 cc.closed = true 1085 return nil 1086 } 1087 1088 func (cc *checkCloser) Read(b []byte) (int, error) { 1089 return len(b), nil 1090 } 1091 1092 // Issue 23643: panic on body copy error 1093 func TestReverseProxy_PanicBodyError(t *testing.T) { 1094 log.SetOutput(io.Discard) 1095 defer log.SetOutput(os.Stderr) 1096 backendServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 1097 out := "this call was relayed by the reverse proxy" 1098 // Coerce a wrong content length to induce io.ErrUnexpectedEOF 1099 w.Header().Set("Content-Length", fmt.Sprintf("%d", len(out)*2)) 1100 fmt.Fprintln(w, out) 1101 })) 1102 defer backendServer.Close() 1103 1104 rpURL, err := url.Parse(backendServer.URL) 1105 if err != nil { 1106 t.Fatal(err) 1107 } 1108 1109 rproxy := NewSingleHostReverseProxy(rpURL) 1110 1111 // Ensure that the handler panics when the body read encounters an 1112 // io.ErrUnexpectedEOF 1113 defer func() { 1114 err := recover() 1115 if err == nil { 1116 t.Fatal("handler should have panicked") 1117 } 1118 if err != http.ErrAbortHandler { 1119 t.Fatal("expected ErrAbortHandler, got", err) 1120 } 1121 }() 1122 req, _ := http.NewRequest("GET", "http://foo.tld/", nil) 1123 rproxy.ServeHTTP(httptest.NewRecorder(), req) 1124 } 1125 1126 // Issue #46866: panic without closing incoming request body causes a panic 1127 func TestReverseProxy_PanicClosesIncomingBody(t *testing.T) { 1128 backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 1129 out := "this call was relayed by the reverse proxy" 1130 // Coerce a wrong content length to induce io.ErrUnexpectedEOF 1131 w.Header().Set("Content-Length", fmt.Sprintf("%d", len(out)*2)) 1132 fmt.Fprintln(w, out) 1133 })) 1134 defer backend.Close() 1135 backendURL, err := url.Parse(backend.URL) 1136 if err != nil { 1137 t.Fatal(err) 1138 } 1139 proxyHandler := NewSingleHostReverseProxy(backendURL) 1140 proxyHandler.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests 1141 frontend := httptest.NewServer(proxyHandler) 1142 defer frontend.Close() 1143 frontendClient := frontend.Client() 1144 1145 var wg sync.WaitGroup 1146 for i := 0; i < 2; i++ { 1147 wg.Add(1) 1148 go func() { 1149 defer wg.Done() 1150 for j := 0; j < 10; j++ { 1151 const reqLen = 6 * 1024 * 1024 1152 req, _ := http.NewRequest("POST", frontend.URL, &io.LimitedReader{R: neverEnding('x'), N: reqLen}) 1153 req.ContentLength = reqLen 1154 resp, _ := frontendClient.Transport.RoundTrip(req) 1155 if resp != nil { 1156 io.Copy(io.Discard, resp.Body) 1157 resp.Body.Close() 1158 } 1159 } 1160 }() 1161 } 1162 wg.Wait() 1163 } 1164 1165 func TestSelectFlushInterval(t *testing.T) { 1166 tests := []struct { 1167 name string 1168 p *ReverseProxy 1169 res *http.Response 1170 want time.Duration 1171 }{ 1172 { 1173 name: "default", 1174 res: &http.Response{}, 1175 p: &ReverseProxy{FlushInterval: 123}, 1176 want: 123, 1177 }, 1178 { 1179 name: "server-sent events overrides non-zero", 1180 res: &http.Response{ 1181 Header: http.Header{ 1182 "Content-Type": {"text/event-stream"}, 1183 }, 1184 }, 1185 p: &ReverseProxy{FlushInterval: 123}, 1186 want: -1, 1187 }, 1188 { 1189 name: "server-sent events overrides zero", 1190 res: &http.Response{ 1191 Header: http.Header{ 1192 "Content-Type": {"text/event-stream"}, 1193 }, 1194 }, 1195 p: &ReverseProxy{FlushInterval: 0}, 1196 want: -1, 1197 }, 1198 { 1199 name: "Content-Length: -1, overrides non-zero", 1200 res: &http.Response{ 1201 ContentLength: -1, 1202 }, 1203 p: &ReverseProxy{FlushInterval: 123}, 1204 want: -1, 1205 }, 1206 { 1207 name: "Content-Length: -1, overrides zero", 1208 res: &http.Response{ 1209 ContentLength: -1, 1210 }, 1211 p: &ReverseProxy{FlushInterval: 0}, 1212 want: -1, 1213 }, 1214 } 1215 for _, tt := range tests { 1216 t.Run(tt.name, func(t *testing.T) { 1217 got := tt.p.flushInterval(tt.res) 1218 if got != tt.want { 1219 t.Errorf("flushLatency = %v; want %v", got, tt.want) 1220 } 1221 }) 1222 } 1223 } 1224 1225 func TestReverseProxyWebSocket(t *testing.T) { 1226 backendServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 1227 if upgradeType(r.Header) != "websocket" { 1228 t.Error("unexpected backend request") 1229 http.Error(w, "unexpected request", 400) 1230 return 1231 } 1232 c, _, err := w.(http.Hijacker).Hijack() 1233 if err != nil { 1234 t.Error(err) 1235 return 1236 } 1237 defer c.Close() 1238 io.WriteString(c, "HTTP/1.1 101 Switching Protocols\r\nConnection: upgrade\r\nUpgrade: WebSocket\r\n\r\n") 1239 bs := bufio.NewScanner(c) 1240 if !bs.Scan() { 1241 t.Errorf("backend failed to read line from client: %v", bs.Err()) 1242 return 1243 } 1244 fmt.Fprintf(c, "backend got %q\n", bs.Text()) 1245 })) 1246 defer backendServer.Close() 1247 1248 backURL, _ := url.Parse(backendServer.URL) 1249 rproxy := NewSingleHostReverseProxy(backURL) 1250 rproxy.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests 1251 rproxy.ModifyResponse = func(res *http.Response) error { 1252 res.Header.Add("X-Modified", "true") 1253 return nil 1254 } 1255 1256 handler := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { 1257 rw.Header().Set("X-Header", "X-Value") 1258 rproxy.ServeHTTP(rw, req) 1259 if got, want := rw.Header().Get("X-Modified"), "true"; got != want { 1260 t.Errorf("response writer X-Modified header = %q; want %q", got, want) 1261 } 1262 }) 1263 1264 frontendProxy := httptest.NewServer(handler) 1265 defer frontendProxy.Close() 1266 1267 req, _ := http.NewRequest("GET", frontendProxy.URL, nil) 1268 req.Header.Set("Connection", "Upgrade") 1269 req.Header.Set("Upgrade", "websocket") 1270 1271 c := frontendProxy.Client() 1272 res, err := c.Do(req) 1273 if err != nil { 1274 t.Fatal(err) 1275 } 1276 if res.StatusCode != 101 { 1277 t.Fatalf("status = %v; want 101", res.Status) 1278 } 1279 1280 got := res.Header.Get("X-Header") 1281 want := "X-Value" 1282 if got != want { 1283 t.Errorf("Header(XHeader) = %q; want %q", got, want) 1284 } 1285 1286 if !ascii.EqualFold(upgradeType(res.Header), "websocket") { 1287 t.Fatalf("not websocket upgrade; got %#v", res.Header) 1288 } 1289 rwc, ok := res.Body.(io.ReadWriteCloser) 1290 if !ok { 1291 t.Fatalf("response body is of type %T; does not implement ReadWriteCloser", res.Body) 1292 } 1293 defer rwc.Close() 1294 1295 if got, want := res.Header.Get("X-Modified"), "true"; got != want { 1296 t.Errorf("response X-Modified header = %q; want %q", got, want) 1297 } 1298 1299 io.WriteString(rwc, "Hello\n") 1300 bs := bufio.NewScanner(rwc) 1301 if !bs.Scan() { 1302 t.Fatalf("Scan: %v", bs.Err()) 1303 } 1304 got = bs.Text() 1305 want = `backend got "Hello"` 1306 if got != want { 1307 t.Errorf("got %#q, want %#q", got, want) 1308 } 1309 } 1310 1311 func TestReverseProxyWebSocketCancellation(t *testing.T) { 1312 n := 5 1313 triggerCancelCh := make(chan bool, n) 1314 nthResponse := func(i int) string { 1315 return fmt.Sprintf("backend response #%d\n", i) 1316 } 1317 terminalMsg := "final message" 1318 1319 cst := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 1320 if g, ws := upgradeType(r.Header), "websocket"; g != ws { 1321 t.Errorf("Unexpected upgrade type %q, want %q", g, ws) 1322 http.Error(w, "Unexpected request", 400) 1323 return 1324 } 1325 conn, bufrw, err := w.(http.Hijacker).Hijack() 1326 if err != nil { 1327 t.Error(err) 1328 return 1329 } 1330 defer conn.Close() 1331 1332 upgradeMsg := "HTTP/1.1 101 Switching Protocols\r\nConnection: upgrade\r\nUpgrade: WebSocket\r\n\r\n" 1333 if _, err := io.WriteString(conn, upgradeMsg); err != nil { 1334 t.Error(err) 1335 return 1336 } 1337 if _, _, err := bufrw.ReadLine(); err != nil { 1338 t.Errorf("Failed to read line from client: %v", err) 1339 return 1340 } 1341 1342 for i := 0; i < n; i++ { 1343 if _, err := bufrw.WriteString(nthResponse(i)); err != nil { 1344 select { 1345 case <-triggerCancelCh: 1346 default: 1347 t.Errorf("Writing response #%d failed: %v", i, err) 1348 } 1349 return 1350 } 1351 bufrw.Flush() 1352 time.Sleep(time.Second) 1353 } 1354 if _, err := bufrw.WriteString(terminalMsg); err != nil { 1355 select { 1356 case <-triggerCancelCh: 1357 default: 1358 t.Errorf("Failed to write terminal message: %v", err) 1359 } 1360 } 1361 bufrw.Flush() 1362 })) 1363 defer cst.Close() 1364 1365 backendURL, _ := url.Parse(cst.URL) 1366 rproxy := NewSingleHostReverseProxy(backendURL) 1367 rproxy.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests 1368 rproxy.ModifyResponse = func(res *http.Response) error { 1369 res.Header.Add("X-Modified", "true") 1370 return nil 1371 } 1372 1373 handler := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { 1374 rw.Header().Set("X-Header", "X-Value") 1375 ctx, cancel := context.WithCancel(req.Context()) 1376 go func() { 1377 <-triggerCancelCh 1378 cancel() 1379 }() 1380 rproxy.ServeHTTP(rw, req.WithContext(ctx)) 1381 }) 1382 1383 frontendProxy := httptest.NewServer(handler) 1384 defer frontendProxy.Close() 1385 1386 req, _ := http.NewRequest("GET", frontendProxy.URL, nil) 1387 req.Header.Set("Connection", "Upgrade") 1388 req.Header.Set("Upgrade", "websocket") 1389 1390 res, err := frontendProxy.Client().Do(req) 1391 if err != nil { 1392 t.Fatalf("Dialing to frontend proxy: %v", err) 1393 } 1394 defer res.Body.Close() 1395 if g, w := res.StatusCode, 101; g != w { 1396 t.Fatalf("Switching protocols failed, got: %d, want: %d", g, w) 1397 } 1398 1399 if g, w := res.Header.Get("X-Header"), "X-Value"; g != w { 1400 t.Errorf("X-Header mismatch\n\tgot: %q\n\twant: %q", g, w) 1401 } 1402 1403 if g, w := upgradeType(res.Header), "websocket"; !ascii.EqualFold(g, w) { 1404 t.Fatalf("Upgrade header mismatch\n\tgot: %q\n\twant: %q", g, w) 1405 } 1406 1407 rwc, ok := res.Body.(io.ReadWriteCloser) 1408 if !ok { 1409 t.Fatalf("Response body type mismatch, got %T, want io.ReadWriteCloser", res.Body) 1410 } 1411 1412 if got, want := res.Header.Get("X-Modified"), "true"; got != want { 1413 t.Errorf("response X-Modified header = %q; want %q", got, want) 1414 } 1415 1416 if _, err := io.WriteString(rwc, "Hello\n"); err != nil { 1417 t.Fatalf("Failed to write first message: %v", err) 1418 } 1419 1420 // Read loop. 1421 1422 br := bufio.NewReader(rwc) 1423 for { 1424 line, err := br.ReadString('\n') 1425 switch { 1426 case line == terminalMsg: // this case before "err == io.EOF" 1427 t.Fatalf("The websocket request was not canceled, unfortunately!") 1428 1429 case err == io.EOF: 1430 return 1431 1432 case err != nil: 1433 t.Fatalf("Unexpected error: %v", err) 1434 1435 case line == nthResponse(0): // We've gotten the first response back 1436 // Let's trigger a cancel. 1437 close(triggerCancelCh) 1438 } 1439 } 1440 } 1441 1442 func TestUnannouncedTrailer(t *testing.T) { 1443 backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 1444 w.WriteHeader(http.StatusOK) 1445 w.(http.Flusher).Flush() 1446 w.Header().Set(http.TrailerPrefix+"X-Unannounced-Trailer", "unannounced_trailer_value") 1447 })) 1448 defer backend.Close() 1449 backendURL, err := url.Parse(backend.URL) 1450 if err != nil { 1451 t.Fatal(err) 1452 } 1453 proxyHandler := NewSingleHostReverseProxy(backendURL) 1454 proxyHandler.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests 1455 frontend := httptest.NewServer(proxyHandler) 1456 defer frontend.Close() 1457 frontendClient := frontend.Client() 1458 1459 res, err := frontendClient.Get(frontend.URL) 1460 if err != nil { 1461 t.Fatalf("Get: %v", err) 1462 } 1463 1464 io.ReadAll(res.Body) 1465 1466 if g, w := res.Trailer.Get("X-Unannounced-Trailer"), "unannounced_trailer_value"; g != w { 1467 t.Errorf("Trailer(X-Unannounced-Trailer) = %q; want %q", g, w) 1468 } 1469 1470 } 1471 1472 func TestSingleJoinSlash(t *testing.T) { 1473 tests := []struct { 1474 slasha string 1475 slashb string 1476 expected string 1477 }{ 1478 {"https://www.google.com/", "/favicon.ico", "https://www.google.com/favicon.ico"}, 1479 {"https://www.google.com", "/favicon.ico", "https://www.google.com/favicon.ico"}, 1480 {"https://www.google.com", "favicon.ico", "https://www.google.com/favicon.ico"}, 1481 {"https://www.google.com", "", "https://www.google.com/"}, 1482 {"", "favicon.ico", "/favicon.ico"}, 1483 } 1484 for _, tt := range tests { 1485 if got := singleJoiningSlash(tt.slasha, tt.slashb); got != tt.expected { 1486 t.Errorf("singleJoiningSlash(%q,%q) want %q got %q", 1487 tt.slasha, 1488 tt.slashb, 1489 tt.expected, 1490 got) 1491 } 1492 } 1493 } 1494 1495 func TestJoinURLPath(t *testing.T) { 1496 tests := []struct { 1497 a *url.URL 1498 b *url.URL 1499 wantPath string 1500 wantRaw string 1501 }{ 1502 {&url.URL{Path: "/a/b"}, &url.URL{Path: "/c"}, "/a/b/c", ""}, 1503 {&url.URL{Path: "/a/b", RawPath: "badpath"}, &url.URL{Path: "c"}, "/a/b/c", "/a/b/c"}, 1504 {&url.URL{Path: "/a/b", RawPath: "/a%2Fb"}, &url.URL{Path: "/c"}, "/a/b/c", "/a%2Fb/c"}, 1505 {&url.URL{Path: "/a/b", RawPath: "/a%2Fb"}, &url.URL{Path: "/c"}, "/a/b/c", "/a%2Fb/c"}, 1506 {&url.URL{Path: "/a/b/", RawPath: "/a%2Fb%2F"}, &url.URL{Path: "c"}, "/a/b//c", "/a%2Fb%2F/c"}, 1507 {&url.URL{Path: "/a/b/", RawPath: "/a%2Fb/"}, &url.URL{Path: "/c/d", RawPath: "/c%2Fd"}, "/a/b/c/d", "/a%2Fb/c%2Fd"}, 1508 } 1509 1510 for _, tt := range tests { 1511 p, rp := joinURLPath(tt.a, tt.b) 1512 if p != tt.wantPath || rp != tt.wantRaw { 1513 t.Errorf("joinURLPath(URL(%q,%q),URL(%q,%q)) want (%q,%q) got (%q,%q)", 1514 tt.a.Path, tt.a.RawPath, 1515 tt.b.Path, tt.b.RawPath, 1516 tt.wantPath, tt.wantRaw, 1517 p, rp) 1518 } 1519 } 1520 }