github.com/ck00004/CobaltStrikeParser-Go@v1.0.14/lib/http/httputil/reverseproxy_test.go (about) 1 // Copyright 2011 The Go Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 // Reverse proxy tests. 6 7 package httputil 8 9 import ( 10 "bufio" 11 "bytes" 12 "context" 13 "errors" 14 "fmt" 15 "io" 16 "log" 17 "os" 18 "reflect" 19 "sort" 20 "strconv" 21 "strings" 22 "sync" 23 "testing" 24 "time" 25 26 "github.com/ck00004/CobaltStrikeParser-Go/lib/url" 27 28 "github.com/ck00004/CobaltStrikeParser-Go/lib/http/internal/ascii" 29 30 "github.com/ck00004/CobaltStrikeParser-Go/lib/http" 31 32 "github.com/ck00004/CobaltStrikeParser-Go/lib/http/httptest" 33 ) 34 35 const fakeHopHeader = "X-Fake-Hop-Header-For-Test" 36 37 func init() { 38 inOurTests = true 39 hopHeaders = append(hopHeaders, fakeHopHeader) 40 } 41 42 func TestReverseProxy(t *testing.T) { 43 const backendResponse = "I am the backend" 44 const backendStatus = 404 45 backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 46 if r.Method == "GET" && r.FormValue("mode") == "hangup" { 47 c, _, _ := w.(http.Hijacker).Hijack() 48 c.Close() 49 return 50 } 51 if len(r.TransferEncoding) > 0 { 52 t.Errorf("backend got unexpected TransferEncoding: %v", r.TransferEncoding) 53 } 54 if r.Header.Get("X-Forwarded-For") == "" { 55 t.Errorf("didn't get X-Forwarded-For header") 56 } 57 if c := r.Header.Get("Connection"); c != "" { 58 t.Errorf("handler got Connection header value %q", c) 59 } 60 if c := r.Header.Get("Te"); c != "trailers" { 61 t.Errorf("handler got Te header value %q; want 'trailers'", c) 62 } 63 if c := r.Header.Get("Upgrade"); c != "" { 64 t.Errorf("handler got Upgrade header value %q", c) 65 } 66 if c := r.Header.Get("Proxy-Connection"); c != "" { 67 t.Errorf("handler got Proxy-Connection header value %q", c) 68 } 69 if g, e := r.Host, "some-name"; g != e { 70 t.Errorf("backend got Host header %q, want %q", g, e) 71 } 72 w.Header().Set("Trailers", "not a special header field name") 73 w.Header().Set("Trailer", "X-Trailer") 74 w.Header().Set("X-Foo", "bar") 75 w.Header().Set("Upgrade", "foo") 76 w.Header().Set(fakeHopHeader, "foo") 77 w.Header().Add("X-Multi-Value", "foo") 78 w.Header().Add("X-Multi-Value", "bar") 79 http.SetCookie(w, &http.Cookie{Name: "flavor", Value: "chocolateChip"}) 80 w.WriteHeader(backendStatus) 81 w.Write([]byte(backendResponse)) 82 w.Header().Set("X-Trailer", "trailer_value") 83 w.Header().Set(http.TrailerPrefix+"X-Unannounced-Trailer", "unannounced_trailer_value") 84 })) 85 defer backend.Close() 86 backendURL, err := url.Parse(backend.URL) 87 if err != nil { 88 t.Fatal(err) 89 } 90 proxyHandler := NewSingleHostReverseProxy(backendURL) 91 proxyHandler.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests 92 frontend := httptest.NewServer(proxyHandler) 93 defer frontend.Close() 94 frontendClient := frontend.Client() 95 96 getReq, _ := http.NewRequest("GET", frontend.URL, nil) 97 getReq.Host = "some-name" 98 getReq.Header.Set("Connection", "close, TE") 99 getReq.Header.Add("Te", "foo") 100 getReq.Header.Add("Te", "bar, trailers") 101 getReq.Header.Set("Proxy-Connection", "should be deleted") 102 getReq.Header.Set("Upgrade", "foo") 103 getReq.Close = true 104 res, err := frontendClient.Do(getReq) 105 if err != nil { 106 t.Fatalf("Get: %v", err) 107 } 108 if g, e := res.StatusCode, backendStatus; g != e { 109 t.Errorf("got res.StatusCode %d; expected %d", g, e) 110 } 111 if g, e := res.Header.Get("X-Foo"), "bar"; g != e { 112 t.Errorf("got X-Foo %q; expected %q", g, e) 113 } 114 if c := res.Header.Get(fakeHopHeader); c != "" { 115 t.Errorf("got %s header value %q", fakeHopHeader, c) 116 } 117 if g, e := res.Header.Get("Trailers"), "not a special header field name"; g != e { 118 t.Errorf("header Trailers = %q; want %q", g, e) 119 } 120 if g, e := len(res.Header["X-Multi-Value"]), 2; g != e { 121 t.Errorf("got %d X-Multi-Value header values; expected %d", g, e) 122 } 123 if g, e := len(res.Header["Set-Cookie"]), 1; g != e { 124 t.Fatalf("got %d SetCookies, want %d", g, e) 125 } 126 if g, e := res.Trailer, (http.Header{"X-Trailer": nil}); !reflect.DeepEqual(g, e) { 127 t.Errorf("before reading body, Trailer = %#v; want %#v", g, e) 128 } 129 if cookie := res.Cookies()[0]; cookie.Name != "flavor" { 130 t.Errorf("unexpected cookie %q", cookie.Name) 131 } 132 bodyBytes, _ := io.ReadAll(res.Body) 133 if g, e := string(bodyBytes), backendResponse; g != e { 134 t.Errorf("got body %q; expected %q", g, e) 135 } 136 if g, e := res.Trailer.Get("X-Trailer"), "trailer_value"; g != e { 137 t.Errorf("Trailer(X-Trailer) = %q ; want %q", g, e) 138 } 139 if g, e := res.Trailer.Get("X-Unannounced-Trailer"), "unannounced_trailer_value"; g != e { 140 t.Errorf("Trailer(X-Unannounced-Trailer) = %q ; want %q", g, e) 141 } 142 143 // Test that a backend failing to be reached or one which doesn't return 144 // a response results in a StatusBadGateway. 145 getReq, _ = http.NewRequest("GET", frontend.URL+"/?mode=hangup", nil) 146 getReq.Close = true 147 res, err = frontendClient.Do(getReq) 148 if err != nil { 149 t.Fatal(err) 150 } 151 res.Body.Close() 152 if res.StatusCode != http.StatusBadGateway { 153 t.Errorf("request to bad proxy = %v; want 502 StatusBadGateway", res.Status) 154 } 155 156 } 157 158 // Issue 16875: remove any proxied headers mentioned in the "Connection" 159 // header value. 160 func TestReverseProxyStripHeadersPresentInConnection(t *testing.T) { 161 const fakeConnectionToken = "X-Fake-Connection-Token" 162 const backendResponse = "I am the backend" 163 164 // someConnHeader is some arbitrary header to be declared as a hop-by-hop header 165 // in the Request's Connection header. 166 const someConnHeader = "X-Some-Conn-Header" 167 168 backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 169 if c := r.Header.Get("Connection"); c != "" { 170 t.Errorf("handler got header %q = %q; want empty", "Connection", c) 171 } 172 if c := r.Header.Get(fakeConnectionToken); c != "" { 173 t.Errorf("handler got header %q = %q; want empty", fakeConnectionToken, c) 174 } 175 if c := r.Header.Get(someConnHeader); c != "" { 176 t.Errorf("handler got header %q = %q; want empty", someConnHeader, c) 177 } 178 w.Header().Add("Connection", "Upgrade, "+fakeConnectionToken) 179 w.Header().Add("Connection", someConnHeader) 180 w.Header().Set(someConnHeader, "should be deleted") 181 w.Header().Set(fakeConnectionToken, "should be deleted") 182 io.WriteString(w, backendResponse) 183 })) 184 defer backend.Close() 185 backendURL, err := url.Parse(backend.URL) 186 if err != nil { 187 t.Fatal(err) 188 } 189 proxyHandler := NewSingleHostReverseProxy(backendURL) 190 frontend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 191 proxyHandler.ServeHTTP(w, r) 192 if c := r.Header.Get(someConnHeader); c != "should be deleted" { 193 t.Errorf("handler modified header %q = %q; want %q", someConnHeader, c, "should be deleted") 194 } 195 if c := r.Header.Get(fakeConnectionToken); c != "should be deleted" { 196 t.Errorf("handler modified header %q = %q; want %q", fakeConnectionToken, c, "should be deleted") 197 } 198 c := r.Header["Connection"] 199 var cf []string 200 for _, f := range c { 201 for _, sf := range strings.Split(f, ",") { 202 if sf = strings.TrimSpace(sf); sf != "" { 203 cf = append(cf, sf) 204 } 205 } 206 } 207 sort.Strings(cf) 208 expectedValues := []string{"Upgrade", someConnHeader, fakeConnectionToken} 209 sort.Strings(expectedValues) 210 if !reflect.DeepEqual(cf, expectedValues) { 211 t.Errorf("handler modified header %q = %q; want %q", "Connection", cf, expectedValues) 212 } 213 })) 214 defer frontend.Close() 215 216 getReq, _ := http.NewRequest("GET", frontend.URL, nil) 217 getReq.Header.Add("Connection", "Upgrade, "+fakeConnectionToken) 218 getReq.Header.Add("Connection", someConnHeader) 219 getReq.Header.Set(someConnHeader, "should be deleted") 220 getReq.Header.Set(fakeConnectionToken, "should be deleted") 221 res, err := frontend.Client().Do(getReq) 222 if err != nil { 223 t.Fatalf("Get: %v", err) 224 } 225 defer res.Body.Close() 226 bodyBytes, err := io.ReadAll(res.Body) 227 if err != nil { 228 t.Fatalf("reading body: %v", err) 229 } 230 if got, want := string(bodyBytes), backendResponse; got != want { 231 t.Errorf("got body %q; want %q", got, want) 232 } 233 if c := res.Header.Get("Connection"); c != "" { 234 t.Errorf("handler got header %q = %q; want empty", "Connection", c) 235 } 236 if c := res.Header.Get(someConnHeader); c != "" { 237 t.Errorf("handler got header %q = %q; want empty", someConnHeader, c) 238 } 239 if c := res.Header.Get(fakeConnectionToken); c != "" { 240 t.Errorf("handler got header %q = %q; want empty", fakeConnectionToken, c) 241 } 242 } 243 244 func TestReverseProxyStripEmptyConnection(t *testing.T) { 245 // See Issue 46313. 246 const backendResponse = "I am the backend" 247 248 // someConnHeader is some arbitrary header to be declared as a hop-by-hop header 249 // in the Request's Connection header. 250 const someConnHeader = "X-Some-Conn-Header" 251 252 backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 253 if c := r.Header.Values("Connection"); len(c) != 0 { 254 t.Errorf("handler got header %q = %v; want empty", "Connection", c) 255 } 256 if c := r.Header.Get(someConnHeader); c != "" { 257 t.Errorf("handler got header %q = %q; want empty", someConnHeader, c) 258 } 259 w.Header().Add("Connection", "") 260 w.Header().Add("Connection", someConnHeader) 261 w.Header().Set(someConnHeader, "should be deleted") 262 io.WriteString(w, backendResponse) 263 })) 264 defer backend.Close() 265 backendURL, err := url.Parse(backend.URL) 266 if err != nil { 267 t.Fatal(err) 268 } 269 proxyHandler := NewSingleHostReverseProxy(backendURL) 270 frontend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 271 proxyHandler.ServeHTTP(w, r) 272 if c := r.Header.Get(someConnHeader); c != "should be deleted" { 273 t.Errorf("handler modified header %q = %q; want %q", someConnHeader, c, "should be deleted") 274 } 275 })) 276 defer frontend.Close() 277 278 getReq, _ := http.NewRequest("GET", frontend.URL, nil) 279 getReq.Header.Add("Connection", "") 280 getReq.Header.Add("Connection", someConnHeader) 281 getReq.Header.Set(someConnHeader, "should be deleted") 282 res, err := frontend.Client().Do(getReq) 283 if err != nil { 284 t.Fatalf("Get: %v", err) 285 } 286 defer res.Body.Close() 287 bodyBytes, err := io.ReadAll(res.Body) 288 if err != nil { 289 t.Fatalf("reading body: %v", err) 290 } 291 if got, want := string(bodyBytes), backendResponse; got != want { 292 t.Errorf("got body %q; want %q", got, want) 293 } 294 if c := res.Header.Get("Connection"); c != "" { 295 t.Errorf("handler got header %q = %q; want empty", "Connection", c) 296 } 297 if c := res.Header.Get(someConnHeader); c != "" { 298 t.Errorf("handler got header %q = %q; want empty", someConnHeader, c) 299 } 300 } 301 302 func TestXForwardedFor(t *testing.T) { 303 const prevForwardedFor = "client ip" 304 const backendResponse = "I am the backend" 305 const backendStatus = 404 306 backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 307 if r.Header.Get("X-Forwarded-For") == "" { 308 t.Errorf("didn't get X-Forwarded-For header") 309 } 310 if !strings.Contains(r.Header.Get("X-Forwarded-For"), prevForwardedFor) { 311 t.Errorf("X-Forwarded-For didn't contain prior data") 312 } 313 w.WriteHeader(backendStatus) 314 w.Write([]byte(backendResponse)) 315 })) 316 defer backend.Close() 317 backendURL, err := url.Parse(backend.URL) 318 if err != nil { 319 t.Fatal(err) 320 } 321 proxyHandler := NewSingleHostReverseProxy(backendURL) 322 frontend := httptest.NewServer(proxyHandler) 323 defer frontend.Close() 324 325 getReq, _ := http.NewRequest("GET", frontend.URL, nil) 326 getReq.Host = "some-name" 327 getReq.Header.Set("Connection", "close") 328 getReq.Header.Set("X-Forwarded-For", prevForwardedFor) 329 getReq.Close = true 330 res, err := frontend.Client().Do(getReq) 331 if err != nil { 332 t.Fatalf("Get: %v", err) 333 } 334 if g, e := res.StatusCode, backendStatus; g != e { 335 t.Errorf("got res.StatusCode %d; expected %d", g, e) 336 } 337 bodyBytes, _ := io.ReadAll(res.Body) 338 if g, e := string(bodyBytes), backendResponse; g != e { 339 t.Errorf("got body %q; expected %q", g, e) 340 } 341 } 342 343 // Issue 38079: don't append to X-Forwarded-For if it's present but nil 344 func TestXForwardedFor_Omit(t *testing.T) { 345 backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 346 if v := r.Header.Get("X-Forwarded-For"); v != "" { 347 t.Errorf("got X-Forwarded-For header: %q", v) 348 } 349 w.Write([]byte("hi")) 350 })) 351 defer backend.Close() 352 backendURL, err := url.Parse(backend.URL) 353 if err != nil { 354 t.Fatal(err) 355 } 356 proxyHandler := NewSingleHostReverseProxy(backendURL) 357 frontend := httptest.NewServer(proxyHandler) 358 defer frontend.Close() 359 360 oldDirector := proxyHandler.Director 361 proxyHandler.Director = func(r *http.Request) { 362 r.Header["X-Forwarded-For"] = nil 363 oldDirector(r) 364 } 365 366 getReq, _ := http.NewRequest("GET", frontend.URL, nil) 367 getReq.Host = "some-name" 368 getReq.Close = true 369 res, err := frontend.Client().Do(getReq) 370 if err != nil { 371 t.Fatalf("Get: %v", err) 372 } 373 res.Body.Close() 374 } 375 376 var proxyQueryTests = []struct { 377 baseSuffix string // suffix to add to backend URL 378 reqSuffix string // suffix to add to frontend's request URL 379 want string // what backend should see for final request URL (without ?) 380 }{ 381 {"", "", ""}, 382 {"?sta=tic", "?us=er", "sta=tic&us=er"}, 383 {"", "?us=er", "us=er"}, 384 {"?sta=tic", "", "sta=tic"}, 385 } 386 387 func TestReverseProxyQuery(t *testing.T) { 388 backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 389 w.Header().Set("X-Got-Query", r.URL.RawQuery) 390 w.Write([]byte("hi")) 391 })) 392 defer backend.Close() 393 394 for i, tt := range proxyQueryTests { 395 backendURL, err := url.Parse(backend.URL + tt.baseSuffix) 396 if err != nil { 397 t.Fatal(err) 398 } 399 frontend := httptest.NewServer(NewSingleHostReverseProxy(backendURL)) 400 req, _ := http.NewRequest("GET", frontend.URL+tt.reqSuffix, nil) 401 req.Close = true 402 res, err := frontend.Client().Do(req) 403 if err != nil { 404 t.Fatalf("%d. Get: %v", i, err) 405 } 406 if g, e := res.Header.Get("X-Got-Query"), tt.want; g != e { 407 t.Errorf("%d. got query %q; expected %q", i, g, e) 408 } 409 res.Body.Close() 410 frontend.Close() 411 } 412 } 413 414 func TestReverseProxyFlushInterval(t *testing.T) { 415 const expected = "hi" 416 backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 417 w.Write([]byte(expected)) 418 })) 419 defer backend.Close() 420 421 backendURL, err := url.Parse(backend.URL) 422 if err != nil { 423 t.Fatal(err) 424 } 425 426 proxyHandler := NewSingleHostReverseProxy(backendURL) 427 proxyHandler.FlushInterval = time.Microsecond 428 429 frontend := httptest.NewServer(proxyHandler) 430 defer frontend.Close() 431 432 req, _ := http.NewRequest("GET", frontend.URL, nil) 433 req.Close = true 434 res, err := frontend.Client().Do(req) 435 if err != nil { 436 t.Fatalf("Get: %v", err) 437 } 438 defer res.Body.Close() 439 if bodyBytes, _ := io.ReadAll(res.Body); string(bodyBytes) != expected { 440 t.Errorf("got body %q; expected %q", bodyBytes, expected) 441 } 442 } 443 444 func TestReverseProxyFlushIntervalHeaders(t *testing.T) { 445 const expected = "hi" 446 stopCh := make(chan struct{}) 447 backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 448 w.Header().Add("MyHeader", expected) 449 w.WriteHeader(200) 450 w.(http.Flusher).Flush() 451 <-stopCh 452 })) 453 defer backend.Close() 454 defer close(stopCh) 455 456 backendURL, err := url.Parse(backend.URL) 457 if err != nil { 458 t.Fatal(err) 459 } 460 461 proxyHandler := NewSingleHostReverseProxy(backendURL) 462 proxyHandler.FlushInterval = time.Microsecond 463 464 frontend := httptest.NewServer(proxyHandler) 465 defer frontend.Close() 466 467 req, _ := http.NewRequest("GET", frontend.URL, nil) 468 req.Close = true 469 470 ctx, cancel := context.WithTimeout(req.Context(), 10*time.Second) 471 defer cancel() 472 req = req.WithContext(ctx) 473 474 res, err := frontend.Client().Do(req) 475 if err != nil { 476 t.Fatalf("Get: %v", err) 477 } 478 defer res.Body.Close() 479 480 if res.Header.Get("MyHeader") != expected { 481 t.Errorf("got header %q; expected %q", res.Header.Get("MyHeader"), expected) 482 } 483 } 484 485 func TestReverseProxyCancellation(t *testing.T) { 486 const backendResponse = "I am the backend" 487 488 reqInFlight := make(chan struct{}) 489 backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 490 close(reqInFlight) // cause the client to cancel its request 491 492 select { 493 case <-time.After(10 * time.Second): 494 // Note: this should only happen in broken implementations, and the 495 // closenotify case should be instantaneous. 496 t.Error("Handler never saw CloseNotify") 497 return 498 case <-w.(http.CloseNotifier).CloseNotify(): 499 } 500 501 w.WriteHeader(http.StatusOK) 502 w.Write([]byte(backendResponse)) 503 })) 504 505 defer backend.Close() 506 507 backend.Config.ErrorLog = log.New(io.Discard, "", 0) 508 509 backendURL, err := url.Parse(backend.URL) 510 if err != nil { 511 t.Fatal(err) 512 } 513 514 proxyHandler := NewSingleHostReverseProxy(backendURL) 515 516 // Discards errors of the form: 517 // http: proxy error: read tcp 127.0.0.1:44643: use of closed network connection 518 proxyHandler.ErrorLog = log.New(io.Discard, "", 0) 519 520 frontend := httptest.NewServer(proxyHandler) 521 defer frontend.Close() 522 frontendClient := frontend.Client() 523 524 getReq, _ := http.NewRequest("GET", frontend.URL, nil) 525 go func() { 526 <-reqInFlight 527 frontendClient.Transport.(*http.Transport).CancelRequest(getReq) 528 }() 529 res, err := frontendClient.Do(getReq) 530 if res != nil { 531 t.Errorf("got response %v; want nil", res.Status) 532 } 533 if err == nil { 534 // This should be an error like: 535 // Get "http://127.0.0.1:58079": read tcp 127.0.0.1:58079: 536 // use of closed network connection 537 t.Error("Server.Client().Do() returned nil error; want non-nil error") 538 } 539 } 540 541 func req(t *testing.T, v string) *http.Request { 542 req, err := http.ReadRequest(bufio.NewReader(strings.NewReader(v))) 543 if err != nil { 544 t.Fatal(err) 545 } 546 return req 547 } 548 549 // Issue 12344 550 func TestNilBody(t *testing.T) { 551 backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 552 w.Write([]byte("hi")) 553 })) 554 defer backend.Close() 555 556 frontend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { 557 backURL, _ := url.Parse(backend.URL) 558 rp := NewSingleHostReverseProxy(backURL) 559 r := req(t, "GET / HTTP/1.0\r\n\r\n") 560 r.Body = nil // this accidentally worked in Go 1.4 and below, so keep it working 561 rp.ServeHTTP(w, r) 562 })) 563 defer frontend.Close() 564 565 res, err := http.Get(frontend.URL) 566 if err != nil { 567 t.Fatal(err) 568 } 569 defer res.Body.Close() 570 slurp, err := io.ReadAll(res.Body) 571 if err != nil { 572 t.Fatal(err) 573 } 574 if string(slurp) != "hi" { 575 t.Errorf("Got %q; want %q", slurp, "hi") 576 } 577 } 578 579 // Issue 15524 580 func TestUserAgentHeader(t *testing.T) { 581 const explicitUA = "explicit UA" 582 backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 583 if r.URL.Path == "/noua" { 584 if c := r.Header.Get("User-Agent"); c != "" { 585 t.Errorf("handler got non-empty User-Agent header %q", c) 586 } 587 return 588 } 589 if c := r.Header.Get("User-Agent"); c != explicitUA { 590 t.Errorf("handler got unexpected User-Agent header %q", c) 591 } 592 })) 593 defer backend.Close() 594 backendURL, err := url.Parse(backend.URL) 595 if err != nil { 596 t.Fatal(err) 597 } 598 proxyHandler := NewSingleHostReverseProxy(backendURL) 599 proxyHandler.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests 600 frontend := httptest.NewServer(proxyHandler) 601 defer frontend.Close() 602 frontendClient := frontend.Client() 603 604 getReq, _ := http.NewRequest("GET", frontend.URL, nil) 605 getReq.Header.Set("User-Agent", explicitUA) 606 getReq.Close = true 607 res, err := frontendClient.Do(getReq) 608 if err != nil { 609 t.Fatalf("Get: %v", err) 610 } 611 res.Body.Close() 612 613 getReq, _ = http.NewRequest("GET", frontend.URL+"/noua", nil) 614 getReq.Header.Set("User-Agent", "") 615 getReq.Close = true 616 res, err = frontendClient.Do(getReq) 617 if err != nil { 618 t.Fatalf("Get: %v", err) 619 } 620 res.Body.Close() 621 } 622 623 type bufferPool struct { 624 get func() []byte 625 put func([]byte) 626 } 627 628 func (bp bufferPool) Get() []byte { return bp.get() } 629 func (bp bufferPool) Put(v []byte) { bp.put(v) } 630 631 func TestReverseProxyGetPutBuffer(t *testing.T) { 632 const msg = "hi" 633 backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 634 io.WriteString(w, msg) 635 })) 636 defer backend.Close() 637 638 backendURL, err := url.Parse(backend.URL) 639 if err != nil { 640 t.Fatal(err) 641 } 642 643 var ( 644 mu sync.Mutex 645 log []string 646 ) 647 addLog := func(event string) { 648 mu.Lock() 649 defer mu.Unlock() 650 log = append(log, event) 651 } 652 rp := NewSingleHostReverseProxy(backendURL) 653 const size = 1234 654 rp.BufferPool = bufferPool{ 655 get: func() []byte { 656 addLog("getBuf") 657 return make([]byte, size) 658 }, 659 put: func(p []byte) { 660 addLog("putBuf-" + strconv.Itoa(len(p))) 661 }, 662 } 663 frontend := httptest.NewServer(rp) 664 defer frontend.Close() 665 666 req, _ := http.NewRequest("GET", frontend.URL, nil) 667 req.Close = true 668 res, err := frontend.Client().Do(req) 669 if err != nil { 670 t.Fatalf("Get: %v", err) 671 } 672 slurp, err := io.ReadAll(res.Body) 673 res.Body.Close() 674 if err != nil { 675 t.Fatalf("reading body: %v", err) 676 } 677 if string(slurp) != msg { 678 t.Errorf("msg = %q; want %q", slurp, msg) 679 } 680 wantLog := []string{"getBuf", "putBuf-" + strconv.Itoa(size)} 681 mu.Lock() 682 defer mu.Unlock() 683 if !reflect.DeepEqual(log, wantLog) { 684 t.Errorf("Log events = %q; want %q", log, wantLog) 685 } 686 } 687 688 func TestReverseProxy_Post(t *testing.T) { 689 const backendResponse = "I am the backend" 690 const backendStatus = 200 691 var requestBody = bytes.Repeat([]byte("a"), 1<<20) 692 backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 693 slurp, err := io.ReadAll(r.Body) 694 if err != nil { 695 t.Errorf("Backend body read = %v", err) 696 } 697 if len(slurp) != len(requestBody) { 698 t.Errorf("Backend read %d request body bytes; want %d", len(slurp), len(requestBody)) 699 } 700 if !bytes.Equal(slurp, requestBody) { 701 t.Error("Backend read wrong request body.") // 1MB; omitting details 702 } 703 w.Write([]byte(backendResponse)) 704 })) 705 defer backend.Close() 706 backendURL, err := url.Parse(backend.URL) 707 if err != nil { 708 t.Fatal(err) 709 } 710 proxyHandler := NewSingleHostReverseProxy(backendURL) 711 frontend := httptest.NewServer(proxyHandler) 712 defer frontend.Close() 713 714 postReq, _ := http.NewRequest("POST", frontend.URL, bytes.NewReader(requestBody)) 715 res, err := frontend.Client().Do(postReq) 716 if err != nil { 717 t.Fatalf("Do: %v", err) 718 } 719 if g, e := res.StatusCode, backendStatus; g != e { 720 t.Errorf("got res.StatusCode %d; expected %d", g, e) 721 } 722 bodyBytes, _ := io.ReadAll(res.Body) 723 if g, e := string(bodyBytes), backendResponse; g != e { 724 t.Errorf("got body %q; expected %q", g, e) 725 } 726 } 727 728 type RoundTripperFunc func(*http.Request) (*http.Response, error) 729 730 func (fn RoundTripperFunc) RoundTrip(req *http.Request) (*http.Response, error) { 731 return fn(req) 732 } 733 734 // Issue 16036: send a Request with a nil Body when possible 735 func TestReverseProxy_NilBody(t *testing.T) { 736 backendURL, _ := url.Parse("http://fake.tld/") 737 proxyHandler := NewSingleHostReverseProxy(backendURL) 738 proxyHandler.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests 739 proxyHandler.Transport = RoundTripperFunc(func(req *http.Request) (*http.Response, error) { 740 if req.Body != nil { 741 t.Error("Body != nil; want a nil Body") 742 } 743 return nil, errors.New("done testing the interesting part; so force a 502 Gateway error") 744 }) 745 frontend := httptest.NewServer(proxyHandler) 746 defer frontend.Close() 747 748 res, err := frontend.Client().Get(frontend.URL) 749 if err != nil { 750 t.Fatal(err) 751 } 752 defer res.Body.Close() 753 if res.StatusCode != 502 { 754 t.Errorf("status code = %v; want 502 (Gateway Error)", res.Status) 755 } 756 } 757 758 // Issue 33142: always allocate the request headers 759 func TestReverseProxy_AllocatedHeader(t *testing.T) { 760 proxyHandler := new(ReverseProxy) 761 proxyHandler.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests 762 proxyHandler.Director = func(*http.Request) {} // noop 763 proxyHandler.Transport = RoundTripperFunc(func(req *http.Request) (*http.Response, error) { 764 if req.Header == nil { 765 t.Error("Header == nil; want a non-nil Header") 766 } 767 return nil, errors.New("done testing the interesting part; so force a 502 Gateway error") 768 }) 769 770 proxyHandler.ServeHTTP(httptest.NewRecorder(), &http.Request{ 771 Method: "GET", 772 URL: &url.URL{Scheme: "http", Host: "fake.tld", Path: "/"}, 773 Proto: "HTTP/1.0", 774 ProtoMajor: 1, 775 }) 776 } 777 778 // Issue 14237. Test ModifyResponse and that an error from it 779 // causes the proxy to return StatusBadGateway, or StatusOK otherwise. 780 func TestReverseProxyModifyResponse(t *testing.T) { 781 backendServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 782 w.Header().Add("X-Hit-Mod", fmt.Sprintf("%v", r.URL.Path == "/mod")) 783 })) 784 defer backendServer.Close() 785 786 rpURL, _ := url.Parse(backendServer.URL) 787 rproxy := NewSingleHostReverseProxy(rpURL) 788 rproxy.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests 789 rproxy.ModifyResponse = func(resp *http.Response) error { 790 if resp.Header.Get("X-Hit-Mod") != "true" { 791 return fmt.Errorf("tried to by-pass proxy") 792 } 793 return nil 794 } 795 796 frontendProxy := httptest.NewServer(rproxy) 797 defer frontendProxy.Close() 798 799 tests := []struct { 800 url string 801 wantCode int 802 }{ 803 {frontendProxy.URL + "/mod", http.StatusOK}, 804 {frontendProxy.URL + "/schedule", http.StatusBadGateway}, 805 } 806 807 for i, tt := range tests { 808 resp, err := http.Get(tt.url) 809 if err != nil { 810 t.Fatalf("failed to reach proxy: %v", err) 811 } 812 if g, e := resp.StatusCode, tt.wantCode; g != e { 813 t.Errorf("#%d: got res.StatusCode %d; expected %d", i, g, e) 814 } 815 resp.Body.Close() 816 } 817 } 818 819 type failingRoundTripper struct{} 820 821 func (failingRoundTripper) RoundTrip(*http.Request) (*http.Response, error) { 822 return nil, errors.New("some error") 823 } 824 825 type staticResponseRoundTripper struct{ res *http.Response } 826 827 func (rt staticResponseRoundTripper) RoundTrip(*http.Request) (*http.Response, error) { 828 return rt.res, nil 829 } 830 831 func TestReverseProxyErrorHandler(t *testing.T) { 832 tests := []struct { 833 name string 834 wantCode int 835 errorHandler func(http.ResponseWriter, *http.Request, error) 836 transport http.RoundTripper // defaults to failingRoundTripper 837 modifyResponse func(*http.Response) error 838 }{ 839 { 840 name: "default", 841 wantCode: http.StatusBadGateway, 842 }, 843 { 844 name: "errorhandler", 845 wantCode: http.StatusTeapot, 846 errorHandler: func(rw http.ResponseWriter, req *http.Request, err error) { rw.WriteHeader(http.StatusTeapot) }, 847 }, 848 { 849 name: "modifyresponse_noerr", 850 transport: staticResponseRoundTripper{ 851 &http.Response{StatusCode: 345, Body: http.NoBody}, 852 }, 853 modifyResponse: func(res *http.Response) error { 854 res.StatusCode++ 855 return nil 856 }, 857 errorHandler: func(rw http.ResponseWriter, req *http.Request, err error) { rw.WriteHeader(http.StatusTeapot) }, 858 wantCode: 346, 859 }, 860 { 861 name: "modifyresponse_err", 862 transport: staticResponseRoundTripper{ 863 &http.Response{StatusCode: 345, Body: http.NoBody}, 864 }, 865 modifyResponse: func(res *http.Response) error { 866 res.StatusCode++ 867 return errors.New("some error to trigger errorHandler") 868 }, 869 errorHandler: func(rw http.ResponseWriter, req *http.Request, err error) { rw.WriteHeader(http.StatusTeapot) }, 870 wantCode: http.StatusTeapot, 871 }, 872 } 873 874 for _, tt := range tests { 875 t.Run(tt.name, func(t *testing.T) { 876 target := &url.URL{ 877 Scheme: "http", 878 Host: "dummy.tld", 879 Path: "/", 880 } 881 rproxy := NewSingleHostReverseProxy(target) 882 rproxy.Transport = tt.transport 883 rproxy.ModifyResponse = tt.modifyResponse 884 if rproxy.Transport == nil { 885 rproxy.Transport = failingRoundTripper{} 886 } 887 rproxy.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests 888 if tt.errorHandler != nil { 889 rproxy.ErrorHandler = tt.errorHandler 890 } 891 frontendProxy := httptest.NewServer(rproxy) 892 defer frontendProxy.Close() 893 894 resp, err := http.Get(frontendProxy.URL + "/test") 895 if err != nil { 896 t.Fatalf("failed to reach proxy: %v", err) 897 } 898 if g, e := resp.StatusCode, tt.wantCode; g != e { 899 t.Errorf("got res.StatusCode %d; expected %d", g, e) 900 } 901 resp.Body.Close() 902 }) 903 } 904 } 905 906 // Issue 16659: log errors from short read 907 func TestReverseProxy_CopyBuffer(t *testing.T) { 908 backendServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 909 out := "this call was relayed by the reverse proxy" 910 // Coerce a wrong content length to induce io.UnexpectedEOF 911 w.Header().Set("Content-Length", fmt.Sprintf("%d", len(out)*2)) 912 fmt.Fprintln(w, out) 913 })) 914 defer backendServer.Close() 915 916 rpURL, err := url.Parse(backendServer.URL) 917 if err != nil { 918 t.Fatal(err) 919 } 920 921 var proxyLog bytes.Buffer 922 rproxy := NewSingleHostReverseProxy(rpURL) 923 rproxy.ErrorLog = log.New(&proxyLog, "", log.Lshortfile) 924 donec := make(chan bool, 1) 925 frontendProxy := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 926 defer func() { donec <- true }() 927 rproxy.ServeHTTP(w, r) 928 })) 929 defer frontendProxy.Close() 930 931 if _, err = frontendProxy.Client().Get(frontendProxy.URL); err == nil { 932 t.Fatalf("want non-nil error") 933 } 934 // The race detector complains about the proxyLog usage in logf in copyBuffer 935 // and our usage below with proxyLog.Bytes() so we're explicitly using a 936 // channel to ensure that the ReverseProxy's ServeHTTP is done before we 937 // continue after Get. 938 <-donec 939 940 expected := []string{ 941 "EOF", 942 "read", 943 } 944 for _, phrase := range expected { 945 if !bytes.Contains(proxyLog.Bytes(), []byte(phrase)) { 946 t.Errorf("expected log to contain phrase %q", phrase) 947 } 948 } 949 } 950 951 type staticTransport struct { 952 res *http.Response 953 } 954 955 func (t *staticTransport) RoundTrip(r *http.Request) (*http.Response, error) { 956 return t.res, nil 957 } 958 959 func BenchmarkServeHTTP(b *testing.B) { 960 res := &http.Response{ 961 StatusCode: 200, 962 Body: io.NopCloser(strings.NewReader("")), 963 } 964 proxy := &ReverseProxy{ 965 Director: func(*http.Request) {}, 966 Transport: &staticTransport{res}, 967 } 968 969 w := httptest.NewRecorder() 970 r := httptest.NewRequest("GET", "/", nil) 971 972 b.ReportAllocs() 973 for i := 0; i < b.N; i++ { 974 proxy.ServeHTTP(w, r) 975 } 976 } 977 978 func TestServeHTTPDeepCopy(t *testing.T) { 979 backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 980 w.Write([]byte("Hello Gopher!")) 981 })) 982 defer backend.Close() 983 backendURL, err := url.Parse(backend.URL) 984 if err != nil { 985 t.Fatal(err) 986 } 987 988 type result struct { 989 before, after string 990 } 991 992 resultChan := make(chan result, 1) 993 proxyHandler := NewSingleHostReverseProxy(backendURL) 994 frontend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 995 before := r.URL.String() 996 proxyHandler.ServeHTTP(w, r) 997 after := r.URL.String() 998 resultChan <- result{before: before, after: after} 999 })) 1000 defer frontend.Close() 1001 1002 want := result{before: "/", after: "/"} 1003 1004 res, err := frontend.Client().Get(frontend.URL) 1005 if err != nil { 1006 t.Fatalf("Do: %v", err) 1007 } 1008 res.Body.Close() 1009 1010 got := <-resultChan 1011 if got != want { 1012 t.Errorf("got = %+v; want = %+v", got, want) 1013 } 1014 } 1015 1016 // Issue 18327: verify we always do a deep copy of the Request.Header map 1017 // before any mutations. 1018 func TestClonesRequestHeaders(t *testing.T) { 1019 log.SetOutput(io.Discard) 1020 defer log.SetOutput(os.Stderr) 1021 req, _ := http.NewRequest("GET", "http://foo.tld/", nil) 1022 req.RemoteAddr = "1.2.3.4:56789" 1023 rp := &ReverseProxy{ 1024 Director: func(req *http.Request) { 1025 req.Header.Set("From-Director", "1") 1026 }, 1027 Transport: roundTripperFunc(func(req *http.Request) (*http.Response, error) { 1028 if v := req.Header.Get("From-Director"); v != "1" { 1029 t.Errorf("From-Directory value = %q; want 1", v) 1030 } 1031 return nil, io.EOF 1032 }), 1033 } 1034 rp.ServeHTTP(httptest.NewRecorder(), req) 1035 1036 if req.Header.Get("From-Director") == "1" { 1037 t.Error("Director header mutation modified caller's request") 1038 } 1039 if req.Header.Get("X-Forwarded-For") != "" { 1040 t.Error("X-Forward-For header mutation modified caller's request") 1041 } 1042 1043 } 1044 1045 type roundTripperFunc func(req *http.Request) (*http.Response, error) 1046 1047 func (fn roundTripperFunc) RoundTrip(req *http.Request) (*http.Response, error) { 1048 return fn(req) 1049 } 1050 1051 func TestModifyResponseClosesBody(t *testing.T) { 1052 req, _ := http.NewRequest("GET", "http://foo.tld/", nil) 1053 req.RemoteAddr = "1.2.3.4:56789" 1054 closeCheck := new(checkCloser) 1055 logBuf := new(bytes.Buffer) 1056 outErr := errors.New("ModifyResponse error") 1057 rp := &ReverseProxy{ 1058 Director: func(req *http.Request) {}, 1059 Transport: &staticTransport{&http.Response{ 1060 StatusCode: 200, 1061 Body: closeCheck, 1062 }}, 1063 ErrorLog: log.New(logBuf, "", 0), 1064 ModifyResponse: func(*http.Response) error { 1065 return outErr 1066 }, 1067 } 1068 rec := httptest.NewRecorder() 1069 rp.ServeHTTP(rec, req) 1070 res := rec.Result() 1071 if g, e := res.StatusCode, http.StatusBadGateway; g != e { 1072 t.Errorf("got res.StatusCode %d; expected %d", g, e) 1073 } 1074 if !closeCheck.closed { 1075 t.Errorf("body should have been closed") 1076 } 1077 if g, e := logBuf.String(), outErr.Error(); !strings.Contains(g, e) { 1078 t.Errorf("ErrorLog %q does not contain %q", g, e) 1079 } 1080 } 1081 1082 type checkCloser struct { 1083 closed bool 1084 } 1085 1086 func (cc *checkCloser) Close() error { 1087 cc.closed = true 1088 return nil 1089 } 1090 1091 func (cc *checkCloser) Read(b []byte) (int, error) { 1092 return len(b), nil 1093 } 1094 1095 // Issue 23643: panic on body copy error 1096 func TestReverseProxy_PanicBodyError(t *testing.T) { 1097 log.SetOutput(io.Discard) 1098 defer log.SetOutput(os.Stderr) 1099 backendServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 1100 out := "this call was relayed by the reverse proxy" 1101 // Coerce a wrong content length to induce io.ErrUnexpectedEOF 1102 w.Header().Set("Content-Length", fmt.Sprintf("%d", len(out)*2)) 1103 fmt.Fprintln(w, out) 1104 })) 1105 defer backendServer.Close() 1106 1107 rpURL, err := url.Parse(backendServer.URL) 1108 if err != nil { 1109 t.Fatal(err) 1110 } 1111 1112 rproxy := NewSingleHostReverseProxy(rpURL) 1113 1114 // Ensure that the handler panics when the body read encounters an 1115 // io.ErrUnexpectedEOF 1116 defer func() { 1117 err := recover() 1118 if err == nil { 1119 t.Fatal("handler should have panicked") 1120 } 1121 if err != http.ErrAbortHandler { 1122 t.Fatal("expected ErrAbortHandler, got", err) 1123 } 1124 }() 1125 req, _ := http.NewRequest("GET", "http://foo.tld/", nil) 1126 rproxy.ServeHTTP(httptest.NewRecorder(), req) 1127 } 1128 1129 // Issue #46866: panic without closing incoming request body causes a panic 1130 func TestReverseProxy_PanicClosesIncomingBody(t *testing.T) { 1131 backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 1132 out := "this call was relayed by the reverse proxy" 1133 // Coerce a wrong content length to induce io.ErrUnexpectedEOF 1134 w.Header().Set("Content-Length", fmt.Sprintf("%d", len(out)*2)) 1135 fmt.Fprintln(w, out) 1136 })) 1137 defer backend.Close() 1138 backendURL, err := url.Parse(backend.URL) 1139 if err != nil { 1140 t.Fatal(err) 1141 } 1142 proxyHandler := NewSingleHostReverseProxy(backendURL) 1143 proxyHandler.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests 1144 frontend := httptest.NewServer(proxyHandler) 1145 defer frontend.Close() 1146 frontendClient := frontend.Client() 1147 1148 var wg sync.WaitGroup 1149 for i := 0; i < 2; i++ { 1150 wg.Add(1) 1151 go func() { 1152 defer wg.Done() 1153 for j := 0; j < 10; j++ { 1154 const reqLen = 6 * 1024 * 1024 1155 req, _ := http.NewRequest("POST", frontend.URL, &io.LimitedReader{R: neverEnding('x'), N: reqLen}) 1156 req.ContentLength = reqLen 1157 resp, _ := frontendClient.Transport.RoundTrip(req) 1158 if resp != nil { 1159 io.Copy(io.Discard, resp.Body) 1160 resp.Body.Close() 1161 } 1162 } 1163 }() 1164 } 1165 wg.Wait() 1166 } 1167 1168 func TestSelectFlushInterval(t *testing.T) { 1169 tests := []struct { 1170 name string 1171 p *ReverseProxy 1172 res *http.Response 1173 want time.Duration 1174 }{ 1175 { 1176 name: "default", 1177 res: &http.Response{}, 1178 p: &ReverseProxy{FlushInterval: 123}, 1179 want: 123, 1180 }, 1181 { 1182 name: "server-sent events overrides non-zero", 1183 res: &http.Response{ 1184 Header: http.Header{ 1185 "Content-Type": {"text/event-stream"}, 1186 }, 1187 }, 1188 p: &ReverseProxy{FlushInterval: 123}, 1189 want: -1, 1190 }, 1191 { 1192 name: "server-sent events overrides zero", 1193 res: &http.Response{ 1194 Header: http.Header{ 1195 "Content-Type": {"text/event-stream"}, 1196 }, 1197 }, 1198 p: &ReverseProxy{FlushInterval: 0}, 1199 want: -1, 1200 }, 1201 { 1202 name: "Content-Length: -1, overrides non-zero", 1203 res: &http.Response{ 1204 ContentLength: -1, 1205 }, 1206 p: &ReverseProxy{FlushInterval: 123}, 1207 want: -1, 1208 }, 1209 { 1210 name: "Content-Length: -1, overrides zero", 1211 res: &http.Response{ 1212 ContentLength: -1, 1213 }, 1214 p: &ReverseProxy{FlushInterval: 0}, 1215 want: -1, 1216 }, 1217 } 1218 for _, tt := range tests { 1219 t.Run(tt.name, func(t *testing.T) { 1220 got := tt.p.flushInterval(tt.res) 1221 if got != tt.want { 1222 t.Errorf("flushLatency = %v; want %v", got, tt.want) 1223 } 1224 }) 1225 } 1226 } 1227 1228 func TestReverseProxyWebSocket(t *testing.T) { 1229 backendServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 1230 if upgradeType(r.Header) != "websocket" { 1231 t.Error("unexpected backend request") 1232 http.Error(w, "unexpected request", 400) 1233 return 1234 } 1235 c, _, err := w.(http.Hijacker).Hijack() 1236 if err != nil { 1237 t.Error(err) 1238 return 1239 } 1240 defer c.Close() 1241 io.WriteString(c, "HTTP/1.1 101 Switching Protocols\r\nConnection: upgrade\r\nUpgrade: WebSocket\r\n\r\n") 1242 bs := bufio.NewScanner(c) 1243 if !bs.Scan() { 1244 t.Errorf("backend failed to read line from client: %v", bs.Err()) 1245 return 1246 } 1247 fmt.Fprintf(c, "backend got %q\n", bs.Text()) 1248 })) 1249 defer backendServer.Close() 1250 1251 backURL, _ := url.Parse(backendServer.URL) 1252 rproxy := NewSingleHostReverseProxy(backURL) 1253 rproxy.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests 1254 rproxy.ModifyResponse = func(res *http.Response) error { 1255 res.Header.Add("X-Modified", "true") 1256 return nil 1257 } 1258 1259 handler := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { 1260 rw.Header().Set("X-Header", "X-Value") 1261 rproxy.ServeHTTP(rw, req) 1262 if got, want := rw.Header().Get("X-Modified"), "true"; got != want { 1263 t.Errorf("response writer X-Modified header = %q; want %q", got, want) 1264 } 1265 }) 1266 1267 frontendProxy := httptest.NewServer(handler) 1268 defer frontendProxy.Close() 1269 1270 req, _ := http.NewRequest("GET", frontendProxy.URL, nil) 1271 req.Header.Set("Connection", "Upgrade") 1272 req.Header.Set("Upgrade", "websocket") 1273 1274 c := frontendProxy.Client() 1275 res, err := c.Do(req) 1276 if err != nil { 1277 t.Fatal(err) 1278 } 1279 if res.StatusCode != 101 { 1280 t.Fatalf("status = %v; want 101", res.Status) 1281 } 1282 1283 got := res.Header.Get("X-Header") 1284 want := "X-Value" 1285 if got != want { 1286 t.Errorf("Header(XHeader) = %q; want %q", got, want) 1287 } 1288 1289 if !ascii.EqualFold(upgradeType(res.Header), "websocket") { 1290 t.Fatalf("not websocket upgrade; got %#v", res.Header) 1291 } 1292 rwc, ok := res.Body.(io.ReadWriteCloser) 1293 if !ok { 1294 t.Fatalf("response body is of type %T; does not implement ReadWriteCloser", res.Body) 1295 } 1296 defer rwc.Close() 1297 1298 if got, want := res.Header.Get("X-Modified"), "true"; got != want { 1299 t.Errorf("response X-Modified header = %q; want %q", got, want) 1300 } 1301 1302 io.WriteString(rwc, "Hello\n") 1303 bs := bufio.NewScanner(rwc) 1304 if !bs.Scan() { 1305 t.Fatalf("Scan: %v", bs.Err()) 1306 } 1307 got = bs.Text() 1308 want = `backend got "Hello"` 1309 if got != want { 1310 t.Errorf("got %#q, want %#q", got, want) 1311 } 1312 } 1313 1314 func TestReverseProxyWebSocketCancellation(t *testing.T) { 1315 n := 5 1316 triggerCancelCh := make(chan bool, n) 1317 nthResponse := func(i int) string { 1318 return fmt.Sprintf("backend response #%d\n", i) 1319 } 1320 terminalMsg := "final message" 1321 1322 cst := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 1323 if g, ws := upgradeType(r.Header), "websocket"; g != ws { 1324 t.Errorf("Unexpected upgrade type %q, want %q", g, ws) 1325 http.Error(w, "Unexpected request", 400) 1326 return 1327 } 1328 conn, bufrw, err := w.(http.Hijacker).Hijack() 1329 if err != nil { 1330 t.Error(err) 1331 return 1332 } 1333 defer conn.Close() 1334 1335 upgradeMsg := "HTTP/1.1 101 Switching Protocols\r\nConnection: upgrade\r\nUpgrade: WebSocket\r\n\r\n" 1336 if _, err := io.WriteString(conn, upgradeMsg); err != nil { 1337 t.Error(err) 1338 return 1339 } 1340 if _, _, err := bufrw.ReadLine(); err != nil { 1341 t.Errorf("Failed to read line from client: %v", err) 1342 return 1343 } 1344 1345 for i := 0; i < n; i++ { 1346 if _, err := bufrw.WriteString(nthResponse(i)); err != nil { 1347 select { 1348 case <-triggerCancelCh: 1349 default: 1350 t.Errorf("Writing response #%d failed: %v", i, err) 1351 } 1352 return 1353 } 1354 bufrw.Flush() 1355 time.Sleep(time.Second) 1356 } 1357 if _, err := bufrw.WriteString(terminalMsg); err != nil { 1358 select { 1359 case <-triggerCancelCh: 1360 default: 1361 t.Errorf("Failed to write terminal message: %v", err) 1362 } 1363 } 1364 bufrw.Flush() 1365 })) 1366 defer cst.Close() 1367 1368 backendURL, _ := url.Parse(cst.URL) 1369 rproxy := NewSingleHostReverseProxy(backendURL) 1370 rproxy.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests 1371 rproxy.ModifyResponse = func(res *http.Response) error { 1372 res.Header.Add("X-Modified", "true") 1373 return nil 1374 } 1375 1376 handler := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { 1377 rw.Header().Set("X-Header", "X-Value") 1378 ctx, cancel := context.WithCancel(req.Context()) 1379 go func() { 1380 <-triggerCancelCh 1381 cancel() 1382 }() 1383 rproxy.ServeHTTP(rw, req.WithContext(ctx)) 1384 }) 1385 1386 frontendProxy := httptest.NewServer(handler) 1387 defer frontendProxy.Close() 1388 1389 req, _ := http.NewRequest("GET", frontendProxy.URL, nil) 1390 req.Header.Set("Connection", "Upgrade") 1391 req.Header.Set("Upgrade", "websocket") 1392 1393 res, err := frontendProxy.Client().Do(req) 1394 if err != nil { 1395 t.Fatalf("Dialing to frontend proxy: %v", err) 1396 } 1397 defer res.Body.Close() 1398 if g, w := res.StatusCode, 101; g != w { 1399 t.Fatalf("Switching protocols failed, got: %d, want: %d", g, w) 1400 } 1401 1402 if g, w := res.Header.Get("X-Header"), "X-Value"; g != w { 1403 t.Errorf("X-Header mismatch\n\tgot: %q\n\twant: %q", g, w) 1404 } 1405 1406 if g, w := upgradeType(res.Header), "websocket"; !ascii.EqualFold(g, w) { 1407 t.Fatalf("Upgrade header mismatch\n\tgot: %q\n\twant: %q", g, w) 1408 } 1409 1410 rwc, ok := res.Body.(io.ReadWriteCloser) 1411 if !ok { 1412 t.Fatalf("Response body type mismatch, got %T, want io.ReadWriteCloser", res.Body) 1413 } 1414 1415 if got, want := res.Header.Get("X-Modified"), "true"; got != want { 1416 t.Errorf("response X-Modified header = %q; want %q", got, want) 1417 } 1418 1419 if _, err := io.WriteString(rwc, "Hello\n"); err != nil { 1420 t.Fatalf("Failed to write first message: %v", err) 1421 } 1422 1423 // Read loop. 1424 1425 br := bufio.NewReader(rwc) 1426 for { 1427 line, err := br.ReadString('\n') 1428 switch { 1429 case line == terminalMsg: // this case before "err == io.EOF" 1430 t.Fatalf("The websocket request was not canceled, unfortunately!") 1431 1432 case err == io.EOF: 1433 return 1434 1435 case err != nil: 1436 t.Fatalf("Unexpected error: %v", err) 1437 1438 case line == nthResponse(0): // We've gotten the first response back 1439 // Let's trigger a cancel. 1440 close(triggerCancelCh) 1441 } 1442 } 1443 } 1444 1445 func TestUnannouncedTrailer(t *testing.T) { 1446 backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 1447 w.WriteHeader(http.StatusOK) 1448 w.(http.Flusher).Flush() 1449 w.Header().Set(http.TrailerPrefix+"X-Unannounced-Trailer", "unannounced_trailer_value") 1450 })) 1451 defer backend.Close() 1452 backendURL, err := url.Parse(backend.URL) 1453 if err != nil { 1454 t.Fatal(err) 1455 } 1456 proxyHandler := NewSingleHostReverseProxy(backendURL) 1457 proxyHandler.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests 1458 frontend := httptest.NewServer(proxyHandler) 1459 defer frontend.Close() 1460 frontendClient := frontend.Client() 1461 1462 res, err := frontendClient.Get(frontend.URL) 1463 if err != nil { 1464 t.Fatalf("Get: %v", err) 1465 } 1466 1467 io.ReadAll(res.Body) 1468 1469 if g, w := res.Trailer.Get("X-Unannounced-Trailer"), "unannounced_trailer_value"; g != w { 1470 t.Errorf("Trailer(X-Unannounced-Trailer) = %q; want %q", g, w) 1471 } 1472 1473 } 1474 1475 func TestSingleJoinSlash(t *testing.T) { 1476 tests := []struct { 1477 slasha string 1478 slashb string 1479 expected string 1480 }{ 1481 {"https://www.google.com/", "/favicon.ico", "https://www.google.com/favicon.ico"}, 1482 {"https://www.google.com", "/favicon.ico", "https://www.google.com/favicon.ico"}, 1483 {"https://www.google.com", "favicon.ico", "https://www.google.com/favicon.ico"}, 1484 {"https://www.google.com", "", "https://www.google.com/"}, 1485 {"", "favicon.ico", "/favicon.ico"}, 1486 } 1487 for _, tt := range tests { 1488 if got := singleJoiningSlash(tt.slasha, tt.slashb); got != tt.expected { 1489 t.Errorf("singleJoiningSlash(%q,%q) want %q got %q", 1490 tt.slasha, 1491 tt.slashb, 1492 tt.expected, 1493 got) 1494 } 1495 } 1496 } 1497 1498 func TestJoinURLPath(t *testing.T) { 1499 tests := []struct { 1500 a *url.URL 1501 b *url.URL 1502 wantPath string 1503 wantRaw string 1504 }{ 1505 {&url.URL{Path: "/a/b"}, &url.URL{Path: "/c"}, "/a/b/c", ""}, 1506 {&url.URL{Path: "/a/b", RawPath: "badpath"}, &url.URL{Path: "c"}, "/a/b/c", "/a/b/c"}, 1507 {&url.URL{Path: "/a/b", RawPath: "/a%2Fb"}, &url.URL{Path: "/c"}, "/a/b/c", "/a%2Fb/c"}, 1508 {&url.URL{Path: "/a/b", RawPath: "/a%2Fb"}, &url.URL{Path: "/c"}, "/a/b/c", "/a%2Fb/c"}, 1509 {&url.URL{Path: "/a/b/", RawPath: "/a%2Fb%2F"}, &url.URL{Path: "c"}, "/a/b//c", "/a%2Fb%2F/c"}, 1510 {&url.URL{Path: "/a/b/", RawPath: "/a%2Fb/"}, &url.URL{Path: "/c/d", RawPath: "/c%2Fd"}, "/a/b/c/d", "/a%2Fb/c%2Fd"}, 1511 } 1512 1513 for _, tt := range tests { 1514 p, rp := joinURLPath(tt.a, tt.b) 1515 if p != tt.wantPath || rp != tt.wantRaw { 1516 t.Errorf("joinURLPath(URL(%q,%q),URL(%q,%q)) want (%q,%q) got (%q,%q)", 1517 tt.a.Path, tt.a.RawPath, 1518 tt.b.Path, tt.b.RawPath, 1519 tt.wantPath, tt.wantRaw, 1520 p, rp) 1521 } 1522 } 1523 }