gitee.com/ks-custle/core-gm@v0.0.0-20230922171213-b83bdd97b62c/gmhttp/httputil/reverseproxy_test.go (about) 1 // Copyright 2011 The Go Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 // Reverse proxy tests. 6 7 package httputil 8 9 import ( 10 "bufio" 11 "bytes" 12 "context" 13 "errors" 14 "fmt" 15 "io" 16 "log" 17 "net" 18 "net/url" 19 "os" 20 "reflect" 21 "sort" 22 "strconv" 23 "strings" 24 "sync" 25 "testing" 26 "time" 27 28 http "gitee.com/ks-custle/core-gm/gmhttp" 29 "gitee.com/ks-custle/core-gm/gmhttp/httptest" 30 "gitee.com/ks-custle/core-gm/gmhttp/internal/ascii" 31 ) 32 33 const fakeHopHeader = "X-Fake-Hop-Header-For-Test" 34 35 func init() { 36 inOurTests = true 37 hopHeaders = append(hopHeaders, fakeHopHeader) 38 } 39 40 func TestReverseProxy(t *testing.T) { 41 const backendResponse = "I am the backend" 42 const backendStatus = 404 43 backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 44 if r.Method == "GET" && r.FormValue("mode") == "hangup" { 45 c, _, _ := w.(http.Hijacker).Hijack() 46 _ = c.Close() 47 return 48 } 49 if len(r.TransferEncoding) > 0 { 50 t.Errorf("backend got unexpected TransferEncoding: %v", r.TransferEncoding) 51 } 52 if r.Header.Get("X-Forwarded-For") == "" { 53 t.Errorf("didn't get X-Forwarded-For header") 54 } 55 if c := r.Header.Get("Connection"); c != "" { 56 t.Errorf("handler got Connection header value %q", c) 57 } 58 if c := r.Header.Get("Te"); c != "trailers" { 59 t.Errorf("handler got Te header value %q; want 'trailers'", c) 60 } 61 if c := r.Header.Get("Upgrade"); c != "" { 62 t.Errorf("handler got Upgrade header value %q", c) 63 } 64 if c := r.Header.Get("Proxy-Connection"); c != "" { 65 t.Errorf("handler got Proxy-Connection header value %q", c) 66 } 67 if g, e := r.Host, "some-name"; g != e { 68 t.Errorf("backend got Host header %q, want %q", g, e) 69 } 70 w.Header().Set("Trailers", "not a special header field name") 71 w.Header().Set("Trailer", "X-Trailer") 72 w.Header().Set("X-Foo", "bar") 73 w.Header().Set("Upgrade", "foo") 74 w.Header().Set(fakeHopHeader, "foo") 75 w.Header().Add("X-Multi-Value", "foo") 76 w.Header().Add("X-Multi-Value", "bar") 77 http.SetCookie(w, &http.Cookie{Name: "flavor", Value: "chocolateChip"}) 78 w.WriteHeader(backendStatus) 79 _, _ = w.Write([]byte(backendResponse)) 80 w.Header().Set("X-Trailer", "trailer_value") 81 w.Header().Set(http.TrailerPrefix+"X-Unannounced-Trailer", "unannounced_trailer_value") 82 })) 83 defer backend.Close() 84 backendURL, err := url.Parse(backend.URL) 85 if err != nil { 86 t.Fatal(err) 87 } 88 proxyHandler := NewSingleHostReverseProxy(backendURL) 89 proxyHandler.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests 90 frontend := httptest.NewServer(proxyHandler) 91 defer frontend.Close() 92 frontendClient := frontend.Client() 93 94 getReq, _ := http.NewRequest("GET", frontend.URL, nil) 95 getReq.Host = "some-name" 96 getReq.Header.Set("Connection", "close, TE") 97 getReq.Header.Add("Te", "foo") 98 getReq.Header.Add("Te", "bar, trailers") 99 getReq.Header.Set("Proxy-Connection", "should be deleted") 100 getReq.Header.Set("Upgrade", "foo") 101 getReq.Close = true 102 res, err := frontendClient.Do(getReq) 103 if err != nil { 104 t.Fatalf("Get: %v", err) 105 } 106 if g, e := res.StatusCode, backendStatus; g != e { 107 t.Errorf("got res.StatusCode %d; expected %d", g, e) 108 } 109 if g, e := res.Header.Get("X-Foo"), "bar"; g != e { 110 t.Errorf("got X-Foo %q; expected %q", g, e) 111 } 112 if c := res.Header.Get(fakeHopHeader); c != "" { 113 t.Errorf("got %s header value %q", fakeHopHeader, c) 114 } 115 if g, e := res.Header.Get("Trailers"), "not a special header field name"; g != e { 116 t.Errorf("header Trailers = %q; want %q", g, e) 117 } 118 if g, e := len(res.Header["X-Multi-Value"]), 2; g != e { 119 t.Errorf("got %d X-Multi-Value header values; expected %d", g, e) 120 } 121 if g, e := len(res.Header["Set-Cookie"]), 1; g != e { 122 t.Fatalf("got %d SetCookies, want %d", g, e) 123 } 124 if g, e := res.Trailer, (http.Header{"X-Trailer": nil}); !reflect.DeepEqual(g, e) { 125 t.Errorf("before reading body, Trailer = %#v; want %#v", g, e) 126 } 127 if cookie := res.Cookies()[0]; cookie.Name != "flavor" { 128 t.Errorf("unexpected cookie %q", cookie.Name) 129 } 130 bodyBytes, _ := io.ReadAll(res.Body) 131 if g, e := string(bodyBytes), backendResponse; g != e { 132 t.Errorf("got body %q; expected %q", g, e) 133 } 134 if g, e := res.Trailer.Get("X-Trailer"), "trailer_value"; g != e { 135 t.Errorf("Trailer(X-Trailer) = %q ; want %q", g, e) 136 } 137 if g, e := res.Trailer.Get("X-Unannounced-Trailer"), "unannounced_trailer_value"; g != e { 138 t.Errorf("Trailer(X-Unannounced-Trailer) = %q ; want %q", g, e) 139 } 140 141 // Test that a backend failing to be reached or one which doesn't return 142 // a response results in a StatusBadGateway. 143 getReq, _ = http.NewRequest("GET", frontend.URL+"/?mode=hangup", nil) 144 getReq.Close = true 145 res, err = frontendClient.Do(getReq) 146 if err != nil { 147 t.Fatal(err) 148 } 149 _ = res.Body.Close() 150 if res.StatusCode != http.StatusBadGateway { 151 t.Errorf("request to bad proxy = %v; want 502 StatusBadGateway", res.Status) 152 } 153 154 } 155 156 // Issue 16875: remove any proxied headers mentioned in the "Connection" 157 // header value. 158 func TestReverseProxyStripHeadersPresentInConnection(t *testing.T) { 159 const fakeConnectionToken = "X-Fake-Connection-Token" 160 const backendResponse = "I am the backend" 161 162 // someConnHeader is some arbitrary header to be declared as a hop-by-hop header 163 // in the Request's Connection header. 164 const someConnHeader = "X-Some-Conn-Header" 165 166 backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 167 if c := r.Header.Get("Connection"); c != "" { 168 t.Errorf("handler got header %q = %q; want empty", "Connection", c) 169 } 170 if c := r.Header.Get(fakeConnectionToken); c != "" { 171 t.Errorf("handler got header %q = %q; want empty", fakeConnectionToken, c) 172 } 173 if c := r.Header.Get(someConnHeader); c != "" { 174 t.Errorf("handler got header %q = %q; want empty", someConnHeader, c) 175 } 176 w.Header().Add("Connection", "Upgrade, "+fakeConnectionToken) 177 w.Header().Add("Connection", someConnHeader) 178 w.Header().Set(someConnHeader, "should be deleted") 179 w.Header().Set(fakeConnectionToken, "should be deleted") 180 _, _ = io.WriteString(w, backendResponse) 181 })) 182 defer backend.Close() 183 backendURL, err := url.Parse(backend.URL) 184 if err != nil { 185 t.Fatal(err) 186 } 187 proxyHandler := NewSingleHostReverseProxy(backendURL) 188 frontend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 189 proxyHandler.ServeHTTP(w, r) 190 if c := r.Header.Get(someConnHeader); c != "should be deleted" { 191 t.Errorf("handler modified header %q = %q; want %q", someConnHeader, c, "should be deleted") 192 } 193 if c := r.Header.Get(fakeConnectionToken); c != "should be deleted" { 194 t.Errorf("handler modified header %q = %q; want %q", fakeConnectionToken, c, "should be deleted") 195 } 196 c := r.Header["Connection"] 197 var cf []string 198 for _, f := range c { 199 for _, sf := range strings.Split(f, ",") { 200 if sf = strings.TrimSpace(sf); sf != "" { 201 cf = append(cf, sf) 202 } 203 } 204 } 205 sort.Strings(cf) 206 expectedValues := []string{"Upgrade", someConnHeader, fakeConnectionToken} 207 sort.Strings(expectedValues) 208 if !reflect.DeepEqual(cf, expectedValues) { 209 t.Errorf("handler modified header %q = %q; want %q", "Connection", cf, expectedValues) 210 } 211 })) 212 defer frontend.Close() 213 214 getReq, _ := http.NewRequest("GET", frontend.URL, nil) 215 getReq.Header.Add("Connection", "Upgrade, "+fakeConnectionToken) 216 getReq.Header.Add("Connection", someConnHeader) 217 getReq.Header.Set(someConnHeader, "should be deleted") 218 getReq.Header.Set(fakeConnectionToken, "should be deleted") 219 res, err := frontend.Client().Do(getReq) 220 if err != nil { 221 t.Fatalf("Get: %v", err) 222 } 223 defer func(Body io.ReadCloser) { 224 _ = Body.Close() 225 }(res.Body) 226 bodyBytes, err := io.ReadAll(res.Body) 227 if err != nil { 228 t.Fatalf("reading body: %v", err) 229 } 230 if got, want := string(bodyBytes), backendResponse; got != want { 231 t.Errorf("got body %q; want %q", got, want) 232 } 233 if c := res.Header.Get("Connection"); c != "" { 234 t.Errorf("handler got header %q = %q; want empty", "Connection", c) 235 } 236 if c := res.Header.Get(someConnHeader); c != "" { 237 t.Errorf("handler got header %q = %q; want empty", someConnHeader, c) 238 } 239 if c := res.Header.Get(fakeConnectionToken); c != "" { 240 t.Errorf("handler got header %q = %q; want empty", fakeConnectionToken, c) 241 } 242 } 243 244 func TestReverseProxyStripEmptyConnection(t *testing.T) { 245 // See Issue 46313. 246 const backendResponse = "I am the backend" 247 248 // someConnHeader is some arbitrary header to be declared as a hop-by-hop header 249 // in the Request's Connection header. 250 const someConnHeader = "X-Some-Conn-Header" 251 252 backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 253 if c := r.Header.Values("Connection"); len(c) != 0 { 254 t.Errorf("handler got header %q = %v; want empty", "Connection", c) 255 } 256 if c := r.Header.Get(someConnHeader); c != "" { 257 t.Errorf("handler got header %q = %q; want empty", someConnHeader, c) 258 } 259 w.Header().Add("Connection", "") 260 w.Header().Add("Connection", someConnHeader) 261 w.Header().Set(someConnHeader, "should be deleted") 262 _, _ = io.WriteString(w, backendResponse) 263 })) 264 defer backend.Close() 265 backendURL, err := url.Parse(backend.URL) 266 if err != nil { 267 t.Fatal(err) 268 } 269 proxyHandler := NewSingleHostReverseProxy(backendURL) 270 frontend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 271 proxyHandler.ServeHTTP(w, r) 272 if c := r.Header.Get(someConnHeader); c != "should be deleted" { 273 t.Errorf("handler modified header %q = %q; want %q", someConnHeader, c, "should be deleted") 274 } 275 })) 276 defer frontend.Close() 277 278 getReq, _ := http.NewRequest("GET", frontend.URL, nil) 279 getReq.Header.Add("Connection", "") 280 getReq.Header.Add("Connection", someConnHeader) 281 getReq.Header.Set(someConnHeader, "should be deleted") 282 res, err := frontend.Client().Do(getReq) 283 if err != nil { 284 t.Fatalf("Get: %v", err) 285 } 286 defer func(Body io.ReadCloser) { 287 _ = Body.Close() 288 }(res.Body) 289 bodyBytes, err := io.ReadAll(res.Body) 290 if err != nil { 291 t.Fatalf("reading body: %v", err) 292 } 293 if got, want := string(bodyBytes), backendResponse; got != want { 294 t.Errorf("got body %q; want %q", got, want) 295 } 296 if c := res.Header.Get("Connection"); c != "" { 297 t.Errorf("handler got header %q = %q; want empty", "Connection", c) 298 } 299 if c := res.Header.Get(someConnHeader); c != "" { 300 t.Errorf("handler got header %q = %q; want empty", someConnHeader, c) 301 } 302 } 303 304 func TestXForwardedFor(t *testing.T) { 305 const prevForwardedFor = "client ip" 306 const backendResponse = "I am the backend" 307 const backendStatus = 404 308 backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 309 if r.Header.Get("X-Forwarded-For") == "" { 310 t.Errorf("didn't get X-Forwarded-For header") 311 } 312 if !strings.Contains(r.Header.Get("X-Forwarded-For"), prevForwardedFor) { 313 t.Errorf("X-Forwarded-For didn't contain prior data") 314 } 315 w.WriteHeader(backendStatus) 316 _, _ = w.Write([]byte(backendResponse)) 317 })) 318 defer backend.Close() 319 backendURL, err := url.Parse(backend.URL) 320 if err != nil { 321 t.Fatal(err) 322 } 323 proxyHandler := NewSingleHostReverseProxy(backendURL) 324 frontend := httptest.NewServer(proxyHandler) 325 defer frontend.Close() 326 327 getReq, _ := http.NewRequest("GET", frontend.URL, nil) 328 getReq.Host = "some-name" 329 getReq.Header.Set("Connection", "close") 330 getReq.Header.Set("X-Forwarded-For", prevForwardedFor) 331 getReq.Close = true 332 res, err := frontend.Client().Do(getReq) 333 if err != nil { 334 t.Fatalf("Get: %v", err) 335 } 336 if g, e := res.StatusCode, backendStatus; g != e { 337 t.Errorf("got res.StatusCode %d; expected %d", g, e) 338 } 339 bodyBytes, _ := io.ReadAll(res.Body) 340 if g, e := string(bodyBytes), backendResponse; g != e { 341 t.Errorf("got body %q; expected %q", g, e) 342 } 343 } 344 345 // Issue 38079: don't append to X-Forwarded-For if it's present but nil 346 func TestXForwardedFor_Omit(t *testing.T) { 347 backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 348 if v := r.Header.Get("X-Forwarded-For"); v != "" { 349 t.Errorf("got X-Forwarded-For header: %q", v) 350 } 351 _, _ = w.Write([]byte("hi")) 352 })) 353 defer backend.Close() 354 backendURL, err := url.Parse(backend.URL) 355 if err != nil { 356 t.Fatal(err) 357 } 358 proxyHandler := NewSingleHostReverseProxy(backendURL) 359 frontend := httptest.NewServer(proxyHandler) 360 defer frontend.Close() 361 362 oldDirector := proxyHandler.Director 363 proxyHandler.Director = func(r *http.Request) { 364 r.Header["X-Forwarded-For"] = nil 365 oldDirector(r) 366 } 367 368 getReq, _ := http.NewRequest("GET", frontend.URL, nil) 369 getReq.Host = "some-name" 370 getReq.Close = true 371 res, err := frontend.Client().Do(getReq) 372 if err != nil { 373 t.Fatalf("Get: %v", err) 374 } 375 _ = res.Body.Close() 376 } 377 378 var proxyQueryTests = []struct { 379 baseSuffix string // suffix to add to backend URL 380 reqSuffix string // suffix to add to frontend's request URL 381 want string // what backend should see for final request URL (without ?) 382 }{ 383 {"", "", ""}, 384 {"?sta=tic", "?us=er", "sta=tic&us=er"}, 385 {"", "?us=er", "us=er"}, 386 {"?sta=tic", "", "sta=tic"}, 387 } 388 389 func TestReverseProxyQuery(t *testing.T) { 390 backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 391 w.Header().Set("X-Got-Query", r.URL.RawQuery) 392 _, _ = w.Write([]byte("hi")) 393 })) 394 defer backend.Close() 395 396 for i, tt := range proxyQueryTests { 397 backendURL, err := url.Parse(backend.URL + tt.baseSuffix) 398 if err != nil { 399 t.Fatal(err) 400 } 401 frontend := httptest.NewServer(NewSingleHostReverseProxy(backendURL)) 402 req, _ := http.NewRequest("GET", frontend.URL+tt.reqSuffix, nil) 403 req.Close = true 404 res, err := frontend.Client().Do(req) 405 if err != nil { 406 t.Fatalf("%d. Get: %v", i, err) 407 } 408 if g, e := res.Header.Get("X-Got-Query"), tt.want; g != e { 409 t.Errorf("%d. got query %q; expected %q", i, g, e) 410 } 411 _ = res.Body.Close() 412 frontend.Close() 413 } 414 } 415 416 func TestReverseProxyFlushInterval(t *testing.T) { 417 const expected = "hi" 418 backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 419 _, _ = w.Write([]byte(expected)) 420 })) 421 defer backend.Close() 422 423 backendURL, err := url.Parse(backend.URL) 424 if err != nil { 425 t.Fatal(err) 426 } 427 428 proxyHandler := NewSingleHostReverseProxy(backendURL) 429 proxyHandler.FlushInterval = time.Microsecond 430 431 frontend := httptest.NewServer(proxyHandler) 432 defer frontend.Close() 433 434 req, _ := http.NewRequest("GET", frontend.URL, nil) 435 req.Close = true 436 res, err := frontend.Client().Do(req) 437 if err != nil { 438 t.Fatalf("Get: %v", err) 439 } 440 defer func(Body io.ReadCloser) { 441 _ = Body.Close() 442 }(res.Body) 443 if bodyBytes, _ := io.ReadAll(res.Body); string(bodyBytes) != expected { 444 t.Errorf("got body %q; expected %q", bodyBytes, expected) 445 } 446 } 447 448 func TestReverseProxyFlushIntervalHeaders(t *testing.T) { 449 const expected = "hi" 450 stopCh := make(chan struct{}) 451 backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 452 w.Header().Add("MyHeader", expected) 453 w.WriteHeader(200) 454 w.(http.Flusher).Flush() 455 <-stopCh 456 })) 457 defer backend.Close() 458 defer close(stopCh) 459 460 backendURL, err := url.Parse(backend.URL) 461 if err != nil { 462 t.Fatal(err) 463 } 464 465 proxyHandler := NewSingleHostReverseProxy(backendURL) 466 proxyHandler.FlushInterval = time.Microsecond 467 468 frontend := httptest.NewServer(proxyHandler) 469 defer frontend.Close() 470 471 req, _ := http.NewRequest("GET", frontend.URL, nil) 472 req.Close = true 473 474 ctx, cancel := context.WithTimeout(req.Context(), 10*time.Second) 475 defer cancel() 476 req = req.WithContext(ctx) 477 478 res, err := frontend.Client().Do(req) 479 if err != nil { 480 t.Fatalf("Get: %v", err) 481 } 482 defer func(Body io.ReadCloser) { 483 _ = Body.Close() 484 }(res.Body) 485 486 if res.Header.Get("MyHeader") != expected { 487 t.Errorf("got header %q; expected %q", res.Header.Get("MyHeader"), expected) 488 } 489 } 490 491 func TestReverseProxyCancellation(t *testing.T) { 492 const backendResponse = "I am the backend" 493 494 reqInFlight := make(chan struct{}) 495 backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 496 close(reqInFlight) // cause the client to cancel its request 497 498 select { 499 case <-time.After(10 * time.Second): 500 // Note: this should only happen in broken implementations, and the 501 // closenotify case should be instantaneous. 502 t.Error("Handler never saw CloseNotify") 503 return 504 case <-w.(http.CloseNotifier).CloseNotify(): 505 } 506 507 w.WriteHeader(http.StatusOK) 508 _, _ = w.Write([]byte(backendResponse)) 509 })) 510 511 defer backend.Close() 512 513 backend.Config.ErrorLog = log.New(io.Discard, "", 0) 514 515 backendURL, err := url.Parse(backend.URL) 516 if err != nil { 517 t.Fatal(err) 518 } 519 520 proxyHandler := NewSingleHostReverseProxy(backendURL) 521 522 // Discards errors of the form: 523 // http: proxy error: read tcp 127.0.0.1:44643: use of closed network connection 524 proxyHandler.ErrorLog = log.New(io.Discard, "", 0) 525 526 frontend := httptest.NewServer(proxyHandler) 527 defer frontend.Close() 528 frontendClient := frontend.Client() 529 530 getReq, _ := http.NewRequest("GET", frontend.URL, nil) 531 go func() { 532 <-reqInFlight 533 frontendClient.Transport.(*http.Transport).CancelRequest(getReq) 534 }() 535 res, err := frontendClient.Do(getReq) 536 if res != nil { 537 t.Errorf("got response %v; want nil", res.Status) 538 } 539 if err == nil { 540 // This should be an error like: 541 // Get "http://127.0.0.1:58079": read tcp 127.0.0.1:58079: 542 // use of closed network connection 543 t.Error("Server.Client().Do() returned nil error; want non-nil error") 544 } 545 } 546 547 func req(t *testing.T, v string) *http.Request { 548 req, err := http.ReadRequest(bufio.NewReader(strings.NewReader(v))) 549 if err != nil { 550 t.Fatal(err) 551 } 552 return req 553 } 554 555 // Issue 12344 556 func TestNilBody(t *testing.T) { 557 backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 558 _, _ = w.Write([]byte("hi")) 559 })) 560 defer backend.Close() 561 562 frontend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { 563 backURL, _ := url.Parse(backend.URL) 564 rp := NewSingleHostReverseProxy(backURL) 565 r := req(t, "GET / HTTP/1.0\r\n\r\n") 566 r.Body = nil // this accidentally worked in Go 1.4 and below, so keep it working 567 rp.ServeHTTP(w, r) 568 })) 569 defer frontend.Close() 570 571 res, err := http.Get(frontend.URL) 572 if err != nil { 573 t.Fatal(err) 574 } 575 defer func(Body io.ReadCloser) { 576 _ = Body.Close() 577 }(res.Body) 578 slurp, err := io.ReadAll(res.Body) 579 if err != nil { 580 t.Fatal(err) 581 } 582 if string(slurp) != "hi" { 583 t.Errorf("Got %q; want %q", slurp, "hi") 584 } 585 } 586 587 // Issue 15524 588 func TestUserAgentHeader(t *testing.T) { 589 const explicitUA = "explicit UA" 590 backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 591 if r.URL.Path == "/noua" { 592 if c := r.Header.Get("User-Agent"); c != "" { 593 t.Errorf("handler got non-empty User-Agent header %q", c) 594 } 595 return 596 } 597 if c := r.Header.Get("User-Agent"); c != explicitUA { 598 t.Errorf("handler got unexpected User-Agent header %q", c) 599 } 600 })) 601 defer backend.Close() 602 backendURL, err := url.Parse(backend.URL) 603 if err != nil { 604 t.Fatal(err) 605 } 606 proxyHandler := NewSingleHostReverseProxy(backendURL) 607 proxyHandler.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests 608 frontend := httptest.NewServer(proxyHandler) 609 defer frontend.Close() 610 frontendClient := frontend.Client() 611 612 getReq, _ := http.NewRequest("GET", frontend.URL, nil) 613 getReq.Header.Set("User-Agent", explicitUA) 614 getReq.Close = true 615 res, err := frontendClient.Do(getReq) 616 if err != nil { 617 t.Fatalf("Get: %v", err) 618 } 619 _ = res.Body.Close() 620 621 getReq, _ = http.NewRequest("GET", frontend.URL+"/noua", nil) 622 getReq.Header.Set("User-Agent", "") 623 getReq.Close = true 624 res, err = frontendClient.Do(getReq) 625 if err != nil { 626 t.Fatalf("Get: %v", err) 627 } 628 _ = res.Body.Close() 629 } 630 631 type bufferPool struct { 632 get func() []byte 633 put func([]byte) 634 } 635 636 func (bp bufferPool) Get() []byte { return bp.get() } 637 func (bp bufferPool) Put(v []byte) { bp.put(v) } 638 639 func TestReverseProxyGetPutBuffer(t *testing.T) { 640 const msg = "hi" 641 backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 642 _, _ = io.WriteString(w, msg) 643 })) 644 defer backend.Close() 645 646 backendURL, err := url.Parse(backend.URL) 647 if err != nil { 648 t.Fatal(err) 649 } 650 651 var ( 652 mu sync.Mutex 653 logTest []string 654 ) 655 addLog := func(event string) { 656 mu.Lock() 657 defer mu.Unlock() 658 logTest = append(logTest, event) 659 } 660 rp := NewSingleHostReverseProxy(backendURL) 661 const size = 1234 662 rp.BufferPool = bufferPool{ 663 get: func() []byte { 664 addLog("getBuf") 665 return make([]byte, size) 666 }, 667 put: func(p []byte) { 668 addLog("putBuf-" + strconv.Itoa(len(p))) 669 }, 670 } 671 frontend := httptest.NewServer(rp) 672 defer frontend.Close() 673 674 req, _ := http.NewRequest("GET", frontend.URL, nil) 675 req.Close = true 676 res, err := frontend.Client().Do(req) 677 if err != nil { 678 t.Fatalf("Get: %v", err) 679 } 680 slurp, err := io.ReadAll(res.Body) 681 _ = res.Body.Close() 682 if err != nil { 683 t.Fatalf("reading body: %v", err) 684 } 685 if string(slurp) != msg { 686 t.Errorf("msg = %q; want %q", slurp, msg) 687 } 688 wantLog := []string{"getBuf", "putBuf-" + strconv.Itoa(size)} 689 mu.Lock() 690 defer mu.Unlock() 691 if !reflect.DeepEqual(logTest, wantLog) { 692 t.Errorf("Log events = %q; want %q", logTest, wantLog) 693 } 694 } 695 696 func TestReverseProxy_Post(t *testing.T) { 697 const backendResponse = "I am the backend" 698 const backendStatus = 200 699 var requestBody = bytes.Repeat([]byte("a"), 1<<20) 700 backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 701 slurp, err := io.ReadAll(r.Body) 702 if err != nil { 703 t.Errorf("Backend body read = %v", err) 704 } 705 if len(slurp) != len(requestBody) { 706 t.Errorf("Backend read %d request body bytes; want %d", len(slurp), len(requestBody)) 707 } 708 if !bytes.Equal(slurp, requestBody) { 709 t.Error("Backend read wrong request body.") // 1MB; omitting details 710 } 711 _, _ = w.Write([]byte(backendResponse)) 712 })) 713 defer backend.Close() 714 backendURL, err := url.Parse(backend.URL) 715 if err != nil { 716 t.Fatal(err) 717 } 718 proxyHandler := NewSingleHostReverseProxy(backendURL) 719 frontend := httptest.NewServer(proxyHandler) 720 defer frontend.Close() 721 722 postReq, _ := http.NewRequest("POST", frontend.URL, bytes.NewReader(requestBody)) 723 res, err := frontend.Client().Do(postReq) 724 if err != nil { 725 t.Fatalf("Do: %v", err) 726 } 727 if g, e := res.StatusCode, backendStatus; g != e { 728 t.Errorf("got res.StatusCode %d; expected %d", g, e) 729 } 730 bodyBytes, _ := io.ReadAll(res.Body) 731 if g, e := string(bodyBytes), backendResponse; g != e { 732 t.Errorf("got body %q; expected %q", g, e) 733 } 734 } 735 736 type RoundTripperFunc func(*http.Request) (*http.Response, error) 737 738 func (fn RoundTripperFunc) RoundTrip(req *http.Request) (*http.Response, error) { 739 return fn(req) 740 } 741 742 // Issue 16036: send a Request with a nil Body when possible 743 func TestReverseProxy_NilBody(t *testing.T) { 744 backendURL, _ := url.Parse("http://fake.tld/") 745 proxyHandler := NewSingleHostReverseProxy(backendURL) 746 proxyHandler.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests 747 proxyHandler.Transport = RoundTripperFunc(func(req *http.Request) (*http.Response, error) { 748 if req.Body != nil { 749 t.Error("Body != nil; want a nil Body") 750 } 751 return nil, errors.New("done testing the interesting part; so force a 502 Gateway error") 752 }) 753 frontend := httptest.NewServer(proxyHandler) 754 defer frontend.Close() 755 756 res, err := frontend.Client().Get(frontend.URL) 757 if err != nil { 758 t.Fatal(err) 759 } 760 defer func(Body io.ReadCloser) { 761 _ = Body.Close() 762 }(res.Body) 763 if res.StatusCode != 502 { 764 t.Errorf("status code = %v; want 502 (Gateway Error)", res.Status) 765 } 766 } 767 768 // Issue 33142: always allocate the request headers 769 func TestReverseProxy_AllocatedHeader(t *testing.T) { 770 proxyHandler := new(ReverseProxy) 771 proxyHandler.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests 772 proxyHandler.Director = func(*http.Request) {} // noop 773 proxyHandler.Transport = RoundTripperFunc(func(req *http.Request) (*http.Response, error) { 774 if req.Header == nil { 775 t.Error("Header == nil; want a non-nil Header") 776 } 777 return nil, errors.New("done testing the interesting part; so force a 502 Gateway error") 778 }) 779 780 proxyHandler.ServeHTTP(httptest.NewRecorder(), &http.Request{ 781 Method: "GET", 782 URL: &url.URL{Scheme: "http", Host: "fake.tld", Path: "/"}, 783 Proto: "HTTP/1.0", 784 ProtoMajor: 1, 785 }) 786 } 787 788 // Issue 14237. Test ModifyResponse and that an error from it 789 // causes the proxy to return StatusBadGateway, or StatusOK otherwise. 790 func TestReverseProxyModifyResponse(t *testing.T) { 791 backendServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 792 w.Header().Add("X-Hit-Mod", fmt.Sprintf("%v", r.URL.Path == "/mod")) 793 })) 794 defer backendServer.Close() 795 796 rpURL, _ := url.Parse(backendServer.URL) 797 rproxy := NewSingleHostReverseProxy(rpURL) 798 rproxy.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests 799 rproxy.ModifyResponse = func(resp *http.Response) error { 800 if resp.Header.Get("X-Hit-Mod") != "true" { 801 return fmt.Errorf("tried to by-pass proxy") 802 } 803 return nil 804 } 805 806 frontendProxy := httptest.NewServer(rproxy) 807 defer frontendProxy.Close() 808 809 tests := []struct { 810 url string 811 wantCode int 812 }{ 813 {frontendProxy.URL + "/mod", http.StatusOK}, 814 {frontendProxy.URL + "/schedule", http.StatusBadGateway}, 815 } 816 817 for i, tt := range tests { 818 resp, err := http.Get(tt.url) 819 if err != nil { 820 t.Fatalf("failed to reach proxy: %v", err) 821 } 822 if g, e := resp.StatusCode, tt.wantCode; g != e { 823 t.Errorf("#%d: got res.StatusCode %d; expected %d", i, g, e) 824 } 825 _ = resp.Body.Close() 826 } 827 } 828 829 type failingRoundTripper struct{} 830 831 func (failingRoundTripper) RoundTrip(*http.Request) (*http.Response, error) { 832 return nil, errors.New("some error") 833 } 834 835 type staticResponseRoundTripper struct{ res *http.Response } 836 837 func (rt staticResponseRoundTripper) RoundTrip(*http.Request) (*http.Response, error) { 838 return rt.res, nil 839 } 840 841 func TestReverseProxyErrorHandler(t *testing.T) { 842 tests := []struct { 843 name string 844 wantCode int 845 errorHandler func(http.ResponseWriter, *http.Request, error) 846 transport http.RoundTripper // defaults to failingRoundTripper 847 modifyResponse func(*http.Response) error 848 }{ 849 { 850 name: "default", 851 wantCode: http.StatusBadGateway, 852 }, 853 { 854 name: "errorhandler", 855 wantCode: http.StatusTeapot, 856 errorHandler: func(rw http.ResponseWriter, req *http.Request, err error) { rw.WriteHeader(http.StatusTeapot) }, 857 }, 858 { 859 name: "modifyresponse_noerr", 860 transport: staticResponseRoundTripper{ 861 &http.Response{StatusCode: 345, Body: http.NoBody}, 862 }, 863 modifyResponse: func(res *http.Response) error { 864 res.StatusCode++ 865 return nil 866 }, 867 errorHandler: func(rw http.ResponseWriter, req *http.Request, err error) { rw.WriteHeader(http.StatusTeapot) }, 868 wantCode: 346, 869 }, 870 { 871 name: "modifyresponse_err", 872 transport: staticResponseRoundTripper{ 873 &http.Response{StatusCode: 345, Body: http.NoBody}, 874 }, 875 modifyResponse: func(res *http.Response) error { 876 res.StatusCode++ 877 return errors.New("some error to trigger errorHandler") 878 }, 879 errorHandler: func(rw http.ResponseWriter, req *http.Request, err error) { rw.WriteHeader(http.StatusTeapot) }, 880 wantCode: http.StatusTeapot, 881 }, 882 } 883 884 for _, tt := range tests { 885 t.Run(tt.name, func(t *testing.T) { 886 target := &url.URL{ 887 Scheme: "http", 888 Host: "dummy.tld", 889 Path: "/", 890 } 891 rproxy := NewSingleHostReverseProxy(target) 892 rproxy.Transport = tt.transport 893 rproxy.ModifyResponse = tt.modifyResponse 894 if rproxy.Transport == nil { 895 rproxy.Transport = failingRoundTripper{} 896 } 897 rproxy.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests 898 if tt.errorHandler != nil { 899 rproxy.ErrorHandler = tt.errorHandler 900 } 901 frontendProxy := httptest.NewServer(rproxy) 902 defer frontendProxy.Close() 903 904 resp, err := http.Get(frontendProxy.URL + "/test") 905 if err != nil { 906 t.Fatalf("failed to reach proxy: %v", err) 907 } 908 if g, e := resp.StatusCode, tt.wantCode; g != e { 909 t.Errorf("got res.StatusCode %d; expected %d", g, e) 910 } 911 _ = resp.Body.Close() 912 }) 913 } 914 } 915 916 // Issue 16659: log errors from short read 917 func TestReverseProxy_CopyBuffer(t *testing.T) { 918 backendServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 919 out := "this call was relayed by the reverse proxy" 920 // Coerce a wrong content length to induce io.UnexpectedEOF 921 w.Header().Set("Content-Length", fmt.Sprintf("%d", len(out)*2)) 922 _, _ = fmt.Fprintln(w, out) 923 })) 924 defer backendServer.Close() 925 926 rpURL, err := url.Parse(backendServer.URL) 927 if err != nil { 928 t.Fatal(err) 929 } 930 931 var proxyLog bytes.Buffer 932 rproxy := NewSingleHostReverseProxy(rpURL) 933 rproxy.ErrorLog = log.New(&proxyLog, "", log.Lshortfile) 934 donec := make(chan bool, 1) 935 frontendProxy := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 936 defer func() { donec <- true }() 937 rproxy.ServeHTTP(w, r) 938 })) 939 defer frontendProxy.Close() 940 941 if _, err = frontendProxy.Client().Get(frontendProxy.URL); err == nil { 942 t.Fatalf("want non-nil error") 943 } 944 // The race detector complains about the proxyLog usage in logf in copyBuffer 945 // and our usage below with proxyLog.Bytes() so we're explicitly using a 946 // channel to ensure that the ReverseProxy's ServeHTTP is done before we 947 // continue after Get. 948 <-donec 949 950 expected := []string{ 951 "EOF", 952 "read", 953 } 954 for _, phrase := range expected { 955 if !bytes.Contains(proxyLog.Bytes(), []byte(phrase)) { 956 t.Errorf("expected log to contain phrase %q", phrase) 957 } 958 } 959 } 960 961 type staticTransport struct { 962 res *http.Response 963 } 964 965 //goland:noinspection GoUnusedParameter 966 func (t *staticTransport) RoundTrip(r *http.Request) (*http.Response, error) { 967 return t.res, nil 968 } 969 970 func BenchmarkServeHTTP(b *testing.B) { 971 res := &http.Response{ 972 StatusCode: 200, 973 Body: io.NopCloser(strings.NewReader("")), 974 } 975 proxy := &ReverseProxy{ 976 Director: func(*http.Request) {}, 977 Transport: &staticTransport{res}, 978 } 979 980 w := httptest.NewRecorder() 981 r := httptest.NewRequest("GET", "/", nil) 982 983 b.ReportAllocs() 984 for i := 0; i < b.N; i++ { 985 proxy.ServeHTTP(w, r) 986 } 987 } 988 989 func TestServeHTTPDeepCopy(t *testing.T) { 990 backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 991 _, _ = w.Write([]byte("Hello Gopher!")) 992 })) 993 defer backend.Close() 994 backendURL, err := url.Parse(backend.URL) 995 if err != nil { 996 t.Fatal(err) 997 } 998 999 type result struct { 1000 before, after string 1001 } 1002 1003 resultChan := make(chan result, 1) 1004 proxyHandler := NewSingleHostReverseProxy(backendURL) 1005 frontend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 1006 before := r.URL.String() 1007 proxyHandler.ServeHTTP(w, r) 1008 after := r.URL.String() 1009 resultChan <- result{before: before, after: after} 1010 })) 1011 defer frontend.Close() 1012 1013 want := result{before: "/", after: "/"} 1014 1015 res, err := frontend.Client().Get(frontend.URL) 1016 if err != nil { 1017 t.Fatalf("Do: %v", err) 1018 } 1019 _ = res.Body.Close() 1020 1021 got := <-resultChan 1022 if got != want { 1023 t.Errorf("got = %+v; want = %+v", got, want) 1024 } 1025 } 1026 1027 // Issue 18327: verify we always do a deep copy of the Request.Header map 1028 // before any mutations. 1029 func TestClonesRequestHeaders(t *testing.T) { 1030 log.SetOutput(io.Discard) 1031 defer log.SetOutput(os.Stderr) 1032 req, _ := http.NewRequest("GET", "http://foo.tld/", nil) 1033 req.RemoteAddr = "1.2.3.4:56789" 1034 rp := &ReverseProxy{ 1035 Director: func(req *http.Request) { 1036 req.Header.Set("From-Director", "1") 1037 }, 1038 Transport: roundTripperFunc(func(req *http.Request) (*http.Response, error) { 1039 if v := req.Header.Get("From-Director"); v != "1" { 1040 t.Errorf("From-Directory value = %q; want 1", v) 1041 } 1042 return nil, io.EOF 1043 }), 1044 } 1045 rp.ServeHTTP(httptest.NewRecorder(), req) 1046 1047 if req.Header.Get("From-Director") == "1" { 1048 t.Error("Director header mutation modified caller's request") 1049 } 1050 if req.Header.Get("X-Forwarded-For") != "" { 1051 t.Error("X-Forward-For header mutation modified caller's request") 1052 } 1053 1054 } 1055 1056 type roundTripperFunc func(req *http.Request) (*http.Response, error) 1057 1058 func (fn roundTripperFunc) RoundTrip(req *http.Request) (*http.Response, error) { 1059 return fn(req) 1060 } 1061 1062 func TestModifyResponseClosesBody(t *testing.T) { 1063 req, _ := http.NewRequest("GET", "http://foo.tld/", nil) 1064 req.RemoteAddr = "1.2.3.4:56789" 1065 closeCheck := new(checkCloser) 1066 logBuf := new(bytes.Buffer) 1067 outErr := errors.New("ModifyResponse error") 1068 rp := &ReverseProxy{ 1069 Director: func(req *http.Request) {}, 1070 Transport: &staticTransport{&http.Response{ 1071 StatusCode: 200, 1072 Body: closeCheck, 1073 }}, 1074 ErrorLog: log.New(logBuf, "", 0), 1075 ModifyResponse: func(*http.Response) error { 1076 return outErr 1077 }, 1078 } 1079 rec := httptest.NewRecorder() 1080 rp.ServeHTTP(rec, req) 1081 res := rec.Result() 1082 if g, e := res.StatusCode, http.StatusBadGateway; g != e { 1083 t.Errorf("got res.StatusCode %d; expected %d", g, e) 1084 } 1085 if !closeCheck.closed { 1086 t.Errorf("body should have been closed") 1087 } 1088 if g, e := logBuf.String(), outErr.Error(); !strings.Contains(g, e) { 1089 t.Errorf("ErrorLog %q does not contain %q", g, e) 1090 } 1091 } 1092 1093 type checkCloser struct { 1094 closed bool 1095 } 1096 1097 func (cc *checkCloser) Close() error { 1098 cc.closed = true 1099 return nil 1100 } 1101 1102 func (cc *checkCloser) Read(b []byte) (int, error) { 1103 return len(b), nil 1104 } 1105 1106 // Issue 23643: panic on body copy error 1107 func TestReverseProxy_PanicBodyError(t *testing.T) { 1108 log.SetOutput(io.Discard) 1109 defer log.SetOutput(os.Stderr) 1110 backendServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 1111 out := "this call was relayed by the reverse proxy" 1112 // Coerce a wrong content length to induce io.ErrUnexpectedEOF 1113 w.Header().Set("Content-Length", fmt.Sprintf("%d", len(out)*2)) 1114 _, _ = fmt.Fprintln(w, out) 1115 })) 1116 defer backendServer.Close() 1117 1118 rpURL, err := url.Parse(backendServer.URL) 1119 if err != nil { 1120 t.Fatal(err) 1121 } 1122 1123 rproxy := NewSingleHostReverseProxy(rpURL) 1124 1125 // Ensure that the handler panics when the body read encounters an 1126 // io.ErrUnexpectedEOF 1127 defer func() { 1128 err := recover() 1129 if err == nil { 1130 t.Fatal("handler should have panicked") 1131 } 1132 if err != http.ErrAbortHandler { 1133 t.Fatal("expected ErrAbortHandler, got", err) 1134 } 1135 }() 1136 req, _ := http.NewRequest("GET", "http://foo.tld/", nil) 1137 rproxy.ServeHTTP(httptest.NewRecorder(), req) 1138 } 1139 1140 // Issue #46866: panic without closing incoming request body causes a panic 1141 func TestReverseProxy_PanicClosesIncomingBody(t *testing.T) { 1142 backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 1143 out := "this call was relayed by the reverse proxy" 1144 // Coerce a wrong content length to induce io.ErrUnexpectedEOF 1145 w.Header().Set("Content-Length", fmt.Sprintf("%d", len(out)*2)) 1146 _, _ = fmt.Fprintln(w, out) 1147 })) 1148 defer backend.Close() 1149 backendURL, err := url.Parse(backend.URL) 1150 if err != nil { 1151 t.Fatal(err) 1152 } 1153 proxyHandler := NewSingleHostReverseProxy(backendURL) 1154 proxyHandler.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests 1155 frontend := httptest.NewServer(proxyHandler) 1156 defer frontend.Close() 1157 frontendClient := frontend.Client() 1158 1159 var wg sync.WaitGroup 1160 for i := 0; i < 2; i++ { 1161 wg.Add(1) 1162 go func() { 1163 defer wg.Done() 1164 for j := 0; j < 10; j++ { 1165 const reqLen = 6 * 1024 * 1024 1166 req, _ := http.NewRequest("POST", frontend.URL, &io.LimitedReader{R: neverEnding('x'), N: reqLen}) 1167 req.ContentLength = reqLen 1168 resp, _ := frontendClient.Transport.RoundTrip(req) 1169 if resp != nil { 1170 _, _ = io.Copy(io.Discard, resp.Body) 1171 _ = resp.Body.Close() 1172 } 1173 } 1174 }() 1175 } 1176 wg.Wait() 1177 } 1178 1179 func TestSelectFlushInterval(t *testing.T) { 1180 tests := []struct { 1181 name string 1182 p *ReverseProxy 1183 res *http.Response 1184 want time.Duration 1185 }{ 1186 { 1187 name: "default", 1188 res: &http.Response{}, 1189 p: &ReverseProxy{FlushInterval: 123}, 1190 want: 123, 1191 }, 1192 { 1193 name: "server-sent events overrides non-zero", 1194 res: &http.Response{ 1195 Header: http.Header{ 1196 "Content-Type": {"text/event-stream"}, 1197 }, 1198 }, 1199 p: &ReverseProxy{FlushInterval: 123}, 1200 want: -1, 1201 }, 1202 { 1203 name: "server-sent events overrides zero", 1204 res: &http.Response{ 1205 Header: http.Header{ 1206 "Content-Type": {"text/event-stream"}, 1207 }, 1208 }, 1209 p: &ReverseProxy{FlushInterval: 0}, 1210 want: -1, 1211 }, 1212 { 1213 name: "Content-Length: -1, overrides non-zero", 1214 res: &http.Response{ 1215 ContentLength: -1, 1216 }, 1217 p: &ReverseProxy{FlushInterval: 123}, 1218 want: -1, 1219 }, 1220 { 1221 name: "Content-Length: -1, overrides zero", 1222 res: &http.Response{ 1223 ContentLength: -1, 1224 }, 1225 p: &ReverseProxy{FlushInterval: 0}, 1226 want: -1, 1227 }, 1228 } 1229 for _, tt := range tests { 1230 t.Run(tt.name, func(t *testing.T) { 1231 got := tt.p.flushInterval(tt.res) 1232 if got != tt.want { 1233 t.Errorf("flushLatency = %v; want %v", got, tt.want) 1234 } 1235 }) 1236 } 1237 } 1238 1239 func TestReverseProxyWebSocket(t *testing.T) { 1240 backendServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 1241 if upgradeType(r.Header) != "websocket" { 1242 t.Error("unexpected backend request") 1243 http.Error(w, "unexpected request", 400) 1244 return 1245 } 1246 c, _, err := w.(http.Hijacker).Hijack() 1247 if err != nil { 1248 t.Error(err) 1249 return 1250 } 1251 defer func(c net.Conn) { 1252 _ = c.Close() 1253 }(c) 1254 _, _ = io.WriteString(c, "HTTP/1.1 101 Switching Protocols\r\nConnection: upgrade\r\nUpgrade: WebSocket\r\n\r\n") 1255 bs := bufio.NewScanner(c) 1256 if !bs.Scan() { 1257 t.Errorf("backend failed to read line from client: %v", bs.Err()) 1258 return 1259 } 1260 _, _ = fmt.Fprintf(c, "backend got %q\n", bs.Text()) 1261 })) 1262 defer backendServer.Close() 1263 1264 backURL, _ := url.Parse(backendServer.URL) 1265 rproxy := NewSingleHostReverseProxy(backURL) 1266 rproxy.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests 1267 rproxy.ModifyResponse = func(res *http.Response) error { 1268 res.Header.Add("X-Modified", "true") 1269 return nil 1270 } 1271 1272 handler := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { 1273 rw.Header().Set("X-Header", "X-Value") 1274 rproxy.ServeHTTP(rw, req) 1275 if got, want := rw.Header().Get("X-Modified"), "true"; got != want { 1276 t.Errorf("response writer X-Modified header = %q; want %q", got, want) 1277 } 1278 }) 1279 1280 frontendProxy := httptest.NewServer(handler) 1281 defer frontendProxy.Close() 1282 1283 req, _ := http.NewRequest("GET", frontendProxy.URL, nil) 1284 req.Header.Set("Connection", "Upgrade") 1285 req.Header.Set("Upgrade", "websocket") 1286 1287 c := frontendProxy.Client() 1288 res, err := c.Do(req) 1289 if err != nil { 1290 t.Fatal(err) 1291 } 1292 if res.StatusCode != 101 { 1293 t.Fatalf("status = %v; want 101", res.Status) 1294 } 1295 1296 got := res.Header.Get("X-Header") 1297 want := "X-Value" 1298 if got != want { 1299 t.Errorf("Header(XHeader) = %q; want %q", got, want) 1300 } 1301 1302 if !ascii.EqualFold(upgradeType(res.Header), "websocket") { 1303 t.Fatalf("not websocket upgrade; got %#v", res.Header) 1304 } 1305 rwc, ok := res.Body.(io.ReadWriteCloser) 1306 if !ok { 1307 t.Fatalf("response body is of type %T; does not implement ReadWriteCloser", res.Body) 1308 } 1309 defer func(rwc io.ReadWriteCloser) { 1310 _ = rwc.Close() 1311 }(rwc) 1312 1313 if got, want := res.Header.Get("X-Modified"), "true"; got != want { 1314 t.Errorf("response X-Modified header = %q; want %q", got, want) 1315 } 1316 1317 _, _ = io.WriteString(rwc, "Hello\n") 1318 bs := bufio.NewScanner(rwc) 1319 if !bs.Scan() { 1320 t.Fatalf("Scan: %v", bs.Err()) 1321 } 1322 got = bs.Text() 1323 want = `backend got "Hello"` 1324 if got != want { 1325 t.Errorf("got %#q, want %#q", got, want) 1326 } 1327 } 1328 1329 func TestReverseProxyWebSocketCancellation(t *testing.T) { 1330 n := 5 1331 triggerCancelCh := make(chan bool, n) 1332 nthResponse := func(i int) string { 1333 return fmt.Sprintf("backend response #%d\n", i) 1334 } 1335 terminalMsg := "final message" 1336 1337 cst := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 1338 if g, ws := upgradeType(r.Header), "websocket"; g != ws { 1339 t.Errorf("Unexpected upgrade type %q, want %q", g, ws) 1340 http.Error(w, "Unexpected request", 400) 1341 return 1342 } 1343 conn, bufrw, err := w.(http.Hijacker).Hijack() 1344 if err != nil { 1345 t.Error(err) 1346 return 1347 } 1348 defer func(conn net.Conn) { 1349 _ = conn.Close() 1350 }(conn) 1351 1352 upgradeMsg := "HTTP/1.1 101 Switching Protocols\r\nConnection: upgrade\r\nUpgrade: WebSocket\r\n\r\n" 1353 if _, err := io.WriteString(conn, upgradeMsg); err != nil { 1354 t.Error(err) 1355 return 1356 } 1357 if _, _, err := bufrw.ReadLine(); err != nil { 1358 t.Errorf("Failed to read line from client: %v", err) 1359 return 1360 } 1361 1362 for i := 0; i < n; i++ { 1363 if _, err := bufrw.WriteString(nthResponse(i)); err != nil { 1364 select { 1365 case <-triggerCancelCh: 1366 default: 1367 t.Errorf("Writing response #%d failed: %v", i, err) 1368 } 1369 return 1370 } 1371 _ = bufrw.Flush() 1372 time.Sleep(time.Second) 1373 } 1374 if _, err := bufrw.WriteString(terminalMsg); err != nil { 1375 select { 1376 case <-triggerCancelCh: 1377 default: 1378 t.Errorf("Failed to write terminal message: %v", err) 1379 } 1380 } 1381 _ = bufrw.Flush() 1382 })) 1383 defer cst.Close() 1384 1385 backendURL, _ := url.Parse(cst.URL) 1386 rproxy := NewSingleHostReverseProxy(backendURL) 1387 rproxy.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests 1388 rproxy.ModifyResponse = func(res *http.Response) error { 1389 res.Header.Add("X-Modified", "true") 1390 return nil 1391 } 1392 1393 handler := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { 1394 rw.Header().Set("X-Header", "X-Value") 1395 ctx, cancel := context.WithCancel(req.Context()) 1396 go func() { 1397 <-triggerCancelCh 1398 cancel() 1399 }() 1400 rproxy.ServeHTTP(rw, req.WithContext(ctx)) 1401 }) 1402 1403 frontendProxy := httptest.NewServer(handler) 1404 defer frontendProxy.Close() 1405 1406 req, _ := http.NewRequest("GET", frontendProxy.URL, nil) 1407 req.Header.Set("Connection", "Upgrade") 1408 req.Header.Set("Upgrade", "websocket") 1409 1410 res, err := frontendProxy.Client().Do(req) 1411 if err != nil { 1412 t.Fatalf("Dialing to frontend proxy: %v", err) 1413 } 1414 defer func(Body io.ReadCloser) { 1415 _ = Body.Close() 1416 }(res.Body) 1417 if g, w := res.StatusCode, 101; g != w { 1418 t.Fatalf("Switching protocols failed, got: %d, want: %d", g, w) 1419 } 1420 1421 if g, w := res.Header.Get("X-Header"), "X-Value"; g != w { 1422 t.Errorf("X-Header mismatch\n\tgot: %q\n\twant: %q", g, w) 1423 } 1424 1425 if g, w := upgradeType(res.Header), "websocket"; !ascii.EqualFold(g, w) { 1426 t.Fatalf("Upgrade header mismatch\n\tgot: %q\n\twant: %q", g, w) 1427 } 1428 1429 rwc, ok := res.Body.(io.ReadWriteCloser) 1430 if !ok { 1431 t.Fatalf("Response body type mismatch, got %T, want io.ReadWriteCloser", res.Body) 1432 } 1433 1434 if got, want := res.Header.Get("X-Modified"), "true"; got != want { 1435 t.Errorf("response X-Modified header = %q; want %q", got, want) 1436 } 1437 1438 if _, err := io.WriteString(rwc, "Hello\n"); err != nil { 1439 t.Fatalf("Failed to write first message: %v", err) 1440 } 1441 1442 // Read loop. 1443 1444 br := bufio.NewReader(rwc) 1445 for { 1446 line, err := br.ReadString('\n') 1447 switch { 1448 case line == terminalMsg: // this case before "err == io.EOF" 1449 t.Fatalf("The websocket request was not canceled, unfortunately!") 1450 1451 case err == io.EOF: 1452 return 1453 1454 case err != nil: 1455 t.Fatalf("Unexpected error: %v", err) 1456 1457 case line == nthResponse(0): // We've gotten the first response back 1458 // Let's trigger a cancel. 1459 close(triggerCancelCh) 1460 } 1461 } 1462 } 1463 1464 func TestUnannouncedTrailer(t *testing.T) { 1465 backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 1466 w.WriteHeader(http.StatusOK) 1467 w.(http.Flusher).Flush() 1468 w.Header().Set(http.TrailerPrefix+"X-Unannounced-Trailer", "unannounced_trailer_value") 1469 })) 1470 defer backend.Close() 1471 backendURL, err := url.Parse(backend.URL) 1472 if err != nil { 1473 t.Fatal(err) 1474 } 1475 proxyHandler := NewSingleHostReverseProxy(backendURL) 1476 proxyHandler.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests 1477 frontend := httptest.NewServer(proxyHandler) 1478 defer frontend.Close() 1479 frontendClient := frontend.Client() 1480 1481 res, err := frontendClient.Get(frontend.URL) 1482 if err != nil { 1483 t.Fatalf("Get: %v", err) 1484 } 1485 1486 _, _ = io.ReadAll(res.Body) 1487 1488 if g, w := res.Trailer.Get("X-Unannounced-Trailer"), "unannounced_trailer_value"; g != w { 1489 t.Errorf("Trailer(X-Unannounced-Trailer) = %q; want %q", g, w) 1490 } 1491 1492 } 1493 1494 func TestSingleJoinSlash(t *testing.T) { 1495 tests := []struct { 1496 slasha string 1497 slashb string 1498 expected string 1499 }{ 1500 {"https://www.google.com/", "/favicon.ico", "https://www.google.com/favicon.ico"}, 1501 {"https://www.google.com", "/favicon.ico", "https://www.google.com/favicon.ico"}, 1502 {"https://www.google.com", "favicon.ico", "https://www.google.com/favicon.ico"}, 1503 {"https://www.google.com", "", "https://www.google.com/"}, 1504 {"", "favicon.ico", "/favicon.ico"}, 1505 } 1506 for _, tt := range tests { 1507 if got := singleJoiningSlash(tt.slasha, tt.slashb); got != tt.expected { 1508 t.Errorf("singleJoiningSlash(%q,%q) want %q got %q", 1509 tt.slasha, 1510 tt.slashb, 1511 tt.expected, 1512 got) 1513 } 1514 } 1515 } 1516 1517 func TestJoinURLPath(t *testing.T) { 1518 tests := []struct { 1519 a *url.URL 1520 b *url.URL 1521 wantPath string 1522 wantRaw string 1523 }{ 1524 {&url.URL{Path: "/a/b"}, &url.URL{Path: "/c"}, "/a/b/c", ""}, 1525 {&url.URL{Path: "/a/b", RawPath: "badpath"}, &url.URL{Path: "c"}, "/a/b/c", "/a/b/c"}, 1526 {&url.URL{Path: "/a/b", RawPath: "/a%2Fb"}, &url.URL{Path: "/c"}, "/a/b/c", "/a%2Fb/c"}, 1527 {&url.URL{Path: "/a/b", RawPath: "/a%2Fb"}, &url.URL{Path: "/c"}, "/a/b/c", "/a%2Fb/c"}, 1528 {&url.URL{Path: "/a/b/", RawPath: "/a%2Fb%2F"}, &url.URL{Path: "c"}, "/a/b//c", "/a%2Fb%2F/c"}, 1529 {&url.URL{Path: "/a/b/", RawPath: "/a%2Fb/"}, &url.URL{Path: "/c/d", RawPath: "/c%2Fd"}, "/a/b/c/d", "/a%2Fb/c%2Fd"}, 1530 } 1531 1532 for _, tt := range tests { 1533 p, rp := joinURLPath(tt.a, tt.b) 1534 if p != tt.wantPath || rp != tt.wantRaw { 1535 t.Errorf("joinURLPath(URL(%q,%q),URL(%q,%q)) want (%q,%q) got (%q,%q)", 1536 tt.a.Path, tt.a.RawPath, 1537 tt.b.Path, tt.b.RawPath, 1538 tt.wantPath, tt.wantRaw, 1539 p, rp) 1540 } 1541 } 1542 }