github.com/useflyent/fhttp@v0.0.0-20211004035111-333f430cfbbf/httputil/reverseproxy_test.go (about) 1 // Copyright 2011 The Go Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 // Reverse proxy tests. 6 7 package httputil 8 9 import ( 10 "bufio" 11 "bytes" 12 "context" 13 "errors" 14 "fmt" 15 "io" 16 "log" 17 "net/url" 18 "os" 19 "reflect" 20 "sort" 21 "strconv" 22 "strings" 23 "sync" 24 "testing" 25 "time" 26 27 http "github.com/useflyent/fhttp" 28 "github.com/useflyent/fhttp/httptest" 29 ) 30 31 const fakeHopHeader = "X-Fake-Hop-Header-For-Test" 32 33 func init() { 34 inOurTests = true 35 hopHeaders = append(hopHeaders, fakeHopHeader) 36 } 37 38 func TestReverseProxy(t *testing.T) { 39 const backendResponse = "I am the backend" 40 const backendStatus = 404 41 backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 42 if r.Method == "GET" && r.FormValue("mode") == "hangup" { 43 c, _, _ := w.(http.Hijacker).Hijack() 44 c.Close() 45 return 46 } 47 if len(r.TransferEncoding) > 0 { 48 t.Errorf("backend got unexpected TransferEncoding: %v", r.TransferEncoding) 49 } 50 if r.Header.Get("X-Forwarded-For") == "" { 51 t.Errorf("didn't get X-Forwarded-For header") 52 } 53 if c := r.Header.Get("Connection"); c != "" { 54 t.Errorf("handler got Connection header value %q", c) 55 } 56 if c := r.Header.Get("Te"); c != "trailers" { 57 t.Errorf("handler got Te header value %q; want 'trailers'", c) 58 } 59 if c := r.Header.Get("Upgrade"); c != "" { 60 t.Errorf("handler got Upgrade header value %q", c) 61 } 62 if c := r.Header.Get("Proxy-Connection"); c != "" { 63 t.Errorf("handler got Proxy-Connection header value %q", c) 64 } 65 if g, e := r.Host, "some-name"; g != e { 66 t.Errorf("backend got Host header %q, want %q", g, e) 67 } 68 w.Header().Set("Trailers", "not a special header field name") 69 w.Header().Set("Trailer", "X-Trailer") 70 w.Header().Set("X-Foo", "bar") 71 w.Header().Set("Upgrade", "foo") 72 w.Header().Set(fakeHopHeader, "foo") 73 w.Header().Add("X-Multi-Value", "foo") 74 w.Header().Add("X-Multi-Value", "bar") 75 http.SetCookie(w, &http.Cookie{Name: "flavor", Value: "chocolateChip"}) 76 w.WriteHeader(backendStatus) 77 w.Write([]byte(backendResponse)) 78 w.Header().Set("X-Trailer", "trailer_value") 79 w.Header().Set(http.TrailerPrefix+"X-Unannounced-Trailer", "unannounced_trailer_value") 80 })) 81 defer backend.Close() 82 backendURL, err := url.Parse(backend.URL) 83 if err != nil { 84 t.Fatal(err) 85 } 86 proxyHandler := NewSingleHostReverseProxy(backendURL) 87 proxyHandler.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests 88 frontend := httptest.NewServer(proxyHandler) 89 defer frontend.Close() 90 frontendClient := frontend.Client() 91 92 getReq, _ := http.NewRequest("GET", frontend.URL, nil) 93 getReq.Host = "some-name" 94 getReq.Header.Set("Connection", "close") 95 getReq.Header.Set("Te", "trailers") 96 getReq.Header.Set("Proxy-Connection", "should be deleted") 97 getReq.Header.Set("Upgrade", "foo") 98 getReq.Close = true 99 res, err := frontendClient.Do(getReq) 100 if err != nil { 101 t.Fatalf("Get: %v", err) 102 } 103 if g, e := res.StatusCode, backendStatus; g != e { 104 t.Errorf("got res.StatusCode %d; expected %d", g, e) 105 } 106 if g, e := res.Header.Get("X-Foo"), "bar"; g != e { 107 t.Errorf("got X-Foo %q; expected %q", g, e) 108 } 109 if c := res.Header.Get(fakeHopHeader); c != "" { 110 t.Errorf("got %s header value %q", fakeHopHeader, c) 111 } 112 if g, e := res.Header.Get("Trailers"), "not a special header field name"; g != e { 113 t.Errorf("header Trailers = %q; want %q", g, e) 114 } 115 if g, e := len(res.Header["X-Multi-Value"]), 2; g != e { 116 t.Errorf("got %d X-Multi-Value header values; expected %d", g, e) 117 } 118 if g, e := len(res.Header["Set-Cookie"]), 1; g != e { 119 t.Fatalf("got %d SetCookies, want %d", g, e) 120 } 121 if g, e := res.Trailer, (http.Header{"X-Trailer": nil}); !reflect.DeepEqual(g, e) { 122 t.Errorf("before reading body, Trailer = %#v; want %#v", g, e) 123 } 124 if cookie := res.Cookies()[0]; cookie.Name != "flavor" { 125 t.Errorf("unexpected cookie %q", cookie.Name) 126 } 127 bodyBytes, _ := io.ReadAll(res.Body) 128 if g, e := string(bodyBytes), backendResponse; g != e { 129 t.Errorf("got body %q; expected %q", g, e) 130 } 131 if g, e := res.Trailer.Get("X-Trailer"), "trailer_value"; g != e { 132 t.Errorf("Trailer(X-Trailer) = %q ; want %q", g, e) 133 } 134 if g, e := res.Trailer.Get("X-Unannounced-Trailer"), "unannounced_trailer_value"; g != e { 135 t.Errorf("Trailer(X-Unannounced-Trailer) = %q ; want %q", g, e) 136 } 137 138 // Test that a backend failing to be reached or one which doesn't return 139 // a response results in a StatusBadGateway. 140 getReq, _ = http.NewRequest("GET", frontend.URL+"/?mode=hangup", nil) 141 getReq.Close = true 142 res, err = frontendClient.Do(getReq) 143 if err != nil { 144 t.Fatal(err) 145 } 146 res.Body.Close() 147 if res.StatusCode != http.StatusBadGateway { 148 t.Errorf("request to bad proxy = %v; want 502 StatusBadGateway", res.Status) 149 } 150 151 } 152 153 // Issue 16875: remove any proxied headers mentioned in the "Connection" 154 // header value. 155 func TestReverseProxyStripHeadersPresentInConnection(t *testing.T) { 156 const fakeConnectionToken = "X-Fake-Connection-Token" 157 const backendResponse = "I am the backend" 158 159 // someConnHeader is some arbitrary header to be declared as a hop-by-hop header 160 // in the Request's Connection header. 161 const someConnHeader = "X-Some-Conn-Header" 162 163 backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 164 if c := r.Header.Get("Connection"); c != "" { 165 t.Errorf("handler got header %q = %q; want empty", "Connection", c) 166 } 167 if c := r.Header.Get(fakeConnectionToken); c != "" { 168 t.Errorf("handler got header %q = %q; want empty", fakeConnectionToken, c) 169 } 170 if c := r.Header.Get(someConnHeader); c != "" { 171 t.Errorf("handler got header %q = %q; want empty", someConnHeader, c) 172 } 173 w.Header().Add("Connection", "Upgrade, "+fakeConnectionToken) 174 w.Header().Add("Connection", someConnHeader) 175 w.Header().Set(someConnHeader, "should be deleted") 176 w.Header().Set(fakeConnectionToken, "should be deleted") 177 io.WriteString(w, backendResponse) 178 })) 179 defer backend.Close() 180 backendURL, err := url.Parse(backend.URL) 181 if err != nil { 182 t.Fatal(err) 183 } 184 proxyHandler := NewSingleHostReverseProxy(backendURL) 185 frontend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 186 proxyHandler.ServeHTTP(w, r) 187 if c := r.Header.Get(someConnHeader); c != "should be deleted" { 188 t.Errorf("handler modified header %q = %q; want %q", someConnHeader, c, "should be deleted") 189 } 190 if c := r.Header.Get(fakeConnectionToken); c != "should be deleted" { 191 t.Errorf("handler modified header %q = %q; want %q", fakeConnectionToken, c, "should be deleted") 192 } 193 c := r.Header["Connection"] 194 var cf []string 195 for _, f := range c { 196 for _, sf := range strings.Split(f, ",") { 197 if sf = strings.TrimSpace(sf); sf != "" { 198 cf = append(cf, sf) 199 } 200 } 201 } 202 sort.Strings(cf) 203 expectedValues := []string{"Upgrade", someConnHeader, fakeConnectionToken} 204 sort.Strings(expectedValues) 205 if !reflect.DeepEqual(cf, expectedValues) { 206 t.Errorf("handler modified header %q = %q; want %q", "Connection", cf, expectedValues) 207 } 208 })) 209 defer frontend.Close() 210 211 getReq, _ := http.NewRequest("GET", frontend.URL, nil) 212 getReq.Header.Add("Connection", "Upgrade, "+fakeConnectionToken) 213 getReq.Header.Add("Connection", someConnHeader) 214 getReq.Header.Set(someConnHeader, "should be deleted") 215 getReq.Header.Set(fakeConnectionToken, "should be deleted") 216 res, err := frontend.Client().Do(getReq) 217 if err != nil { 218 t.Fatalf("Get: %v", err) 219 } 220 defer res.Body.Close() 221 bodyBytes, err := io.ReadAll(res.Body) 222 if err != nil { 223 t.Fatalf("reading body: %v", err) 224 } 225 if got, want := string(bodyBytes), backendResponse; got != want { 226 t.Errorf("got body %q; want %q", got, want) 227 } 228 if c := res.Header.Get("Connection"); c != "" { 229 t.Errorf("handler got header %q = %q; want empty", "Connection", c) 230 } 231 if c := res.Header.Get(someConnHeader); c != "" { 232 t.Errorf("handler got header %q = %q; want empty", someConnHeader, c) 233 } 234 if c := res.Header.Get(fakeConnectionToken); c != "" { 235 t.Errorf("handler got header %q = %q; want empty", fakeConnectionToken, c) 236 } 237 } 238 239 func TestXForwardedFor(t *testing.T) { 240 const prevForwardedFor = "client ip" 241 const backendResponse = "I am the backend" 242 const backendStatus = 404 243 backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 244 if r.Header.Get("X-Forwarded-For") == "" { 245 t.Errorf("didn't get X-Forwarded-For header") 246 } 247 if !strings.Contains(r.Header.Get("X-Forwarded-For"), prevForwardedFor) { 248 t.Errorf("X-Forwarded-For didn't contain prior data") 249 } 250 w.WriteHeader(backendStatus) 251 w.Write([]byte(backendResponse)) 252 })) 253 defer backend.Close() 254 backendURL, err := url.Parse(backend.URL) 255 if err != nil { 256 t.Fatal(err) 257 } 258 proxyHandler := NewSingleHostReverseProxy(backendURL) 259 frontend := httptest.NewServer(proxyHandler) 260 defer frontend.Close() 261 262 getReq, _ := http.NewRequest("GET", frontend.URL, nil) 263 getReq.Host = "some-name" 264 getReq.Header.Set("Connection", "close") 265 getReq.Header.Set("X-Forwarded-For", prevForwardedFor) 266 getReq.Close = true 267 res, err := frontend.Client().Do(getReq) 268 if err != nil { 269 t.Fatalf("Get: %v", err) 270 } 271 if g, e := res.StatusCode, backendStatus; g != e { 272 t.Errorf("got res.StatusCode %d; expected %d", g, e) 273 } 274 bodyBytes, _ := io.ReadAll(res.Body) 275 if g, e := string(bodyBytes), backendResponse; g != e { 276 t.Errorf("got body %q; expected %q", g, e) 277 } 278 } 279 280 // Issue 38079: don't append to X-Forwarded-For if it's present but nil 281 func TestXForwardedFor_Omit(t *testing.T) { 282 backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 283 if v := r.Header.Get("X-Forwarded-For"); v != "" { 284 t.Errorf("got X-Forwarded-For header: %q", v) 285 } 286 w.Write([]byte("hi")) 287 })) 288 defer backend.Close() 289 backendURL, err := url.Parse(backend.URL) 290 if err != nil { 291 t.Fatal(err) 292 } 293 proxyHandler := NewSingleHostReverseProxy(backendURL) 294 frontend := httptest.NewServer(proxyHandler) 295 defer frontend.Close() 296 297 oldDirector := proxyHandler.Director 298 proxyHandler.Director = func(r *http.Request) { 299 r.Header["X-Forwarded-For"] = nil 300 oldDirector(r) 301 } 302 303 getReq, _ := http.NewRequest("GET", frontend.URL, nil) 304 getReq.Host = "some-name" 305 getReq.Close = true 306 res, err := frontend.Client().Do(getReq) 307 if err != nil { 308 t.Fatalf("Get: %v", err) 309 } 310 res.Body.Close() 311 } 312 313 var proxyQueryTests = []struct { 314 baseSuffix string // suffix to add to backend URL 315 reqSuffix string // suffix to add to frontend's request URL 316 want string // what backend should see for final request URL (without ?) 317 }{ 318 {"", "", ""}, 319 {"?sta=tic", "?us=er", "sta=tic&us=er"}, 320 {"", "?us=er", "us=er"}, 321 {"?sta=tic", "", "sta=tic"}, 322 } 323 324 func TestReverseProxyQuery(t *testing.T) { 325 backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 326 w.Header().Set("X-Got-Query", r.URL.RawQuery) 327 w.Write([]byte("hi")) 328 })) 329 defer backend.Close() 330 331 for i, tt := range proxyQueryTests { 332 backendURL, err := url.Parse(backend.URL + tt.baseSuffix) 333 if err != nil { 334 t.Fatal(err) 335 } 336 frontend := httptest.NewServer(NewSingleHostReverseProxy(backendURL)) 337 req, _ := http.NewRequest("GET", frontend.URL+tt.reqSuffix, nil) 338 req.Close = true 339 res, err := frontend.Client().Do(req) 340 if err != nil { 341 t.Fatalf("%d. Get: %v", i, err) 342 } 343 if g, e := res.Header.Get("X-Got-Query"), tt.want; g != e { 344 t.Errorf("%d. got query %q; expected %q", i, g, e) 345 } 346 res.Body.Close() 347 frontend.Close() 348 } 349 } 350 351 func TestReverseProxyFlushInterval(t *testing.T) { 352 const expected = "hi" 353 backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 354 w.Write([]byte(expected)) 355 })) 356 defer backend.Close() 357 358 backendURL, err := url.Parse(backend.URL) 359 if err != nil { 360 t.Fatal(err) 361 } 362 363 proxyHandler := NewSingleHostReverseProxy(backendURL) 364 proxyHandler.FlushInterval = time.Microsecond 365 366 frontend := httptest.NewServer(proxyHandler) 367 defer frontend.Close() 368 369 req, _ := http.NewRequest("GET", frontend.URL, nil) 370 req.Close = true 371 res, err := frontend.Client().Do(req) 372 if err != nil { 373 t.Fatalf("Get: %v", err) 374 } 375 defer res.Body.Close() 376 if bodyBytes, _ := io.ReadAll(res.Body); string(bodyBytes) != expected { 377 t.Errorf("got body %q; expected %q", bodyBytes, expected) 378 } 379 } 380 381 func TestReverseProxyFlushIntervalHeaders(t *testing.T) { 382 const expected = "hi" 383 stopCh := make(chan struct{}) 384 backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 385 w.Header().Add("MyHeader", expected) 386 w.WriteHeader(200) 387 w.(http.Flusher).Flush() 388 <-stopCh 389 })) 390 defer backend.Close() 391 defer close(stopCh) 392 393 backendURL, err := url.Parse(backend.URL) 394 if err != nil { 395 t.Fatal(err) 396 } 397 398 proxyHandler := NewSingleHostReverseProxy(backendURL) 399 proxyHandler.FlushInterval = time.Microsecond 400 401 frontend := httptest.NewServer(proxyHandler) 402 defer frontend.Close() 403 404 req, _ := http.NewRequest("GET", frontend.URL, nil) 405 req.Close = true 406 407 ctx, cancel := context.WithTimeout(req.Context(), 10*time.Second) 408 defer cancel() 409 req = req.WithContext(ctx) 410 411 res, err := frontend.Client().Do(req) 412 if err != nil { 413 t.Fatalf("Get: %v", err) 414 } 415 defer res.Body.Close() 416 417 if res.Header.Get("MyHeader") != expected { 418 t.Errorf("got header %q; expected %q", res.Header.Get("MyHeader"), expected) 419 } 420 } 421 422 func TestReverseProxyCancellation(t *testing.T) { 423 const backendResponse = "I am the backend" 424 425 reqInFlight := make(chan struct{}) 426 backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 427 close(reqInFlight) // cause the client to cancel its request 428 429 select { 430 case <-time.After(10 * time.Second): 431 // Note: this should only happen in broken implementations, and the 432 // closenotify case should be instantaneous. 433 t.Error("Handler never saw CloseNotify") 434 return 435 case <-w.(http.CloseNotifier).CloseNotify(): 436 } 437 438 w.WriteHeader(http.StatusOK) 439 w.Write([]byte(backendResponse)) 440 })) 441 442 defer backend.Close() 443 444 backend.Config.ErrorLog = log.New(io.Discard, "", 0) 445 446 backendURL, err := url.Parse(backend.URL) 447 if err != nil { 448 t.Fatal(err) 449 } 450 451 proxyHandler := NewSingleHostReverseProxy(backendURL) 452 453 // Discards errors of the form: 454 // http: proxy error: read tcp 127.0.0.1:44643: use of closed network connection 455 proxyHandler.ErrorLog = log.New(io.Discard, "", 0) 456 457 frontend := httptest.NewServer(proxyHandler) 458 defer frontend.Close() 459 frontendClient := frontend.Client() 460 461 getReq, _ := http.NewRequest("GET", frontend.URL, nil) 462 go func() { 463 <-reqInFlight 464 frontendClient.Transport.(*http.Transport).CancelRequest(getReq) 465 }() 466 res, err := frontendClient.Do(getReq) 467 if res != nil { 468 t.Errorf("got response %v; want nil", res.Status) 469 } 470 if err == nil { 471 // This should be an error like: 472 // Get "http://127.0.0.1:58079": read tcp 127.0.0.1:58079: 473 // use of closed network connection 474 t.Error("Server.Client().Do() returned nil error; want non-nil error") 475 } 476 } 477 478 func req(t *testing.T, v string) *http.Request { 479 req, err := http.ReadRequest(bufio.NewReader(strings.NewReader(v))) 480 if err != nil { 481 t.Fatal(err) 482 } 483 return req 484 } 485 486 // Issue 12344 487 func TestNilBody(t *testing.T) { 488 backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 489 w.Write([]byte("hi")) 490 })) 491 defer backend.Close() 492 493 frontend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { 494 backURL, _ := url.Parse(backend.URL) 495 rp := NewSingleHostReverseProxy(backURL) 496 r := req(t, "GET / HTTP/1.0\r\n\r\n") 497 r.Body = nil // this accidentally worked in Go 1.4 and below, so keep it working 498 rp.ServeHTTP(w, r) 499 })) 500 defer frontend.Close() 501 502 res, err := http.Get(frontend.URL) 503 if err != nil { 504 t.Fatal(err) 505 } 506 defer res.Body.Close() 507 slurp, err := io.ReadAll(res.Body) 508 if err != nil { 509 t.Fatal(err) 510 } 511 if string(slurp) != "hi" { 512 t.Errorf("Got %q; want %q", slurp, "hi") 513 } 514 } 515 516 // Issue 15524 517 func TestUserAgentHeader(t *testing.T) { 518 const explicitUA = "explicit UA" 519 backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 520 if r.URL.Path == "/noua" { 521 if c := r.Header.Get("User-Agent"); c != "" { 522 t.Errorf("handler got non-empty User-Agent header %q", c) 523 } 524 return 525 } 526 if c := r.Header.Get("User-Agent"); c != explicitUA { 527 t.Errorf("handler got unexpected User-Agent header %q", c) 528 } 529 })) 530 defer backend.Close() 531 backendURL, err := url.Parse(backend.URL) 532 if err != nil { 533 t.Fatal(err) 534 } 535 proxyHandler := NewSingleHostReverseProxy(backendURL) 536 proxyHandler.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests 537 frontend := httptest.NewServer(proxyHandler) 538 defer frontend.Close() 539 frontendClient := frontend.Client() 540 541 getReq, _ := http.NewRequest("GET", frontend.URL, nil) 542 getReq.Header.Set("User-Agent", explicitUA) 543 getReq.Close = true 544 res, err := frontendClient.Do(getReq) 545 if err != nil { 546 t.Fatalf("Get: %v", err) 547 } 548 res.Body.Close() 549 550 getReq, _ = http.NewRequest("GET", frontend.URL+"/noua", nil) 551 getReq.Header.Set("User-Agent", "") 552 getReq.Close = true 553 res, err = frontendClient.Do(getReq) 554 if err != nil { 555 t.Fatalf("Get: %v", err) 556 } 557 res.Body.Close() 558 } 559 560 type bufferPool struct { 561 get func() []byte 562 put func([]byte) 563 } 564 565 func (bp bufferPool) Get() []byte { return bp.get() } 566 func (bp bufferPool) Put(v []byte) { bp.put(v) } 567 568 func TestReverseProxyGetPutBuffer(t *testing.T) { 569 const msg = "hi" 570 backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 571 io.WriteString(w, msg) 572 })) 573 defer backend.Close() 574 575 backendURL, err := url.Parse(backend.URL) 576 if err != nil { 577 t.Fatal(err) 578 } 579 580 var ( 581 mu sync.Mutex 582 log []string 583 ) 584 addLog := func(event string) { 585 mu.Lock() 586 defer mu.Unlock() 587 log = append(log, event) 588 } 589 rp := NewSingleHostReverseProxy(backendURL) 590 const size = 1234 591 rp.BufferPool = bufferPool{ 592 get: func() []byte { 593 addLog("getBuf") 594 return make([]byte, size) 595 }, 596 put: func(p []byte) { 597 addLog("putBuf-" + strconv.Itoa(len(p))) 598 }, 599 } 600 frontend := httptest.NewServer(rp) 601 defer frontend.Close() 602 603 req, _ := http.NewRequest("GET", frontend.URL, nil) 604 req.Close = true 605 res, err := frontend.Client().Do(req) 606 if err != nil { 607 t.Fatalf("Get: %v", err) 608 } 609 slurp, err := io.ReadAll(res.Body) 610 res.Body.Close() 611 if err != nil { 612 t.Fatalf("reading body: %v", err) 613 } 614 if string(slurp) != msg { 615 t.Errorf("msg = %q; want %q", slurp, msg) 616 } 617 wantLog := []string{"getBuf", "putBuf-" + strconv.Itoa(size)} 618 mu.Lock() 619 defer mu.Unlock() 620 if !reflect.DeepEqual(log, wantLog) { 621 t.Errorf("Log events = %q; want %q", log, wantLog) 622 } 623 } 624 625 func TestReverseProxy_Post(t *testing.T) { 626 const backendResponse = "I am the backend" 627 const backendStatus = 200 628 var requestBody = bytes.Repeat([]byte("a"), 1<<20) 629 backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 630 slurp, err := io.ReadAll(r.Body) 631 if err != nil { 632 t.Errorf("Backend body read = %v", err) 633 } 634 if len(slurp) != len(requestBody) { 635 t.Errorf("Backend read %d request body bytes; want %d", len(slurp), len(requestBody)) 636 } 637 if !bytes.Equal(slurp, requestBody) { 638 t.Error("Backend read wrong request body.") // 1MB; omitting details 639 } 640 w.Write([]byte(backendResponse)) 641 })) 642 defer backend.Close() 643 backendURL, err := url.Parse(backend.URL) 644 if err != nil { 645 t.Fatal(err) 646 } 647 proxyHandler := NewSingleHostReverseProxy(backendURL) 648 frontend := httptest.NewServer(proxyHandler) 649 defer frontend.Close() 650 651 postReq, _ := http.NewRequest("POST", frontend.URL, bytes.NewReader(requestBody)) 652 res, err := frontend.Client().Do(postReq) 653 if err != nil { 654 t.Fatalf("Do: %v", err) 655 } 656 if g, e := res.StatusCode, backendStatus; g != e { 657 t.Errorf("got res.StatusCode %d; expected %d", g, e) 658 } 659 bodyBytes, _ := io.ReadAll(res.Body) 660 if g, e := string(bodyBytes), backendResponse; g != e { 661 t.Errorf("got body %q; expected %q", g, e) 662 } 663 } 664 665 type RoundTripperFunc func(*http.Request) (*http.Response, error) 666 667 func (fn RoundTripperFunc) RoundTrip(req *http.Request) (*http.Response, error) { 668 return fn(req) 669 } 670 671 // Issue 16036: send a Request with a nil Body when possible 672 func TestReverseProxy_NilBody(t *testing.T) { 673 backendURL, _ := url.Parse("http://fake.tld/") 674 proxyHandler := NewSingleHostReverseProxy(backendURL) 675 proxyHandler.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests 676 proxyHandler.Transport = RoundTripperFunc(func(req *http.Request) (*http.Response, error) { 677 if req.Body != nil { 678 t.Error("Body != nil; want a nil Body") 679 } 680 return nil, errors.New("done testing the interesting part; so force a 502 Gateway error") 681 }) 682 frontend := httptest.NewServer(proxyHandler) 683 defer frontend.Close() 684 685 res, err := frontend.Client().Get(frontend.URL) 686 if err != nil { 687 t.Fatal(err) 688 } 689 defer res.Body.Close() 690 if res.StatusCode != 502 { 691 t.Errorf("status code = %v; want 502 (Gateway Error)", res.Status) 692 } 693 } 694 695 // Issue 33142: always allocate the request headers 696 func TestReverseProxy_AllocatedHeader(t *testing.T) { 697 proxyHandler := new(ReverseProxy) 698 proxyHandler.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests 699 proxyHandler.Director = func(*http.Request) {} // noop 700 proxyHandler.Transport = RoundTripperFunc(func(req *http.Request) (*http.Response, error) { 701 if req.Header == nil { 702 t.Error("Header == nil; want a non-nil Header") 703 } 704 return nil, errors.New("done testing the interesting part; so force a 502 Gateway error") 705 }) 706 707 proxyHandler.ServeHTTP(httptest.NewRecorder(), &http.Request{ 708 Method: "GET", 709 URL: &url.URL{Scheme: "http", Host: "fake.tld", Path: "/"}, 710 Proto: "HTTP/1.0", 711 ProtoMajor: 1, 712 }) 713 } 714 715 // Issue 14237. Test ModifyResponse and that an error from it 716 // causes the proxy to return StatusBadGateway, or StatusOK otherwise. 717 func TestReverseProxyModifyResponse(t *testing.T) { 718 backendServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 719 w.Header().Add("X-Hit-Mod", fmt.Sprintf("%v", r.URL.Path == "/mod")) 720 })) 721 defer backendServer.Close() 722 723 rpURL, _ := url.Parse(backendServer.URL) 724 rproxy := NewSingleHostReverseProxy(rpURL) 725 rproxy.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests 726 rproxy.ModifyResponse = func(resp *http.Response) error { 727 if resp.Header.Get("X-Hit-Mod") != "true" { 728 return fmt.Errorf("tried to by-pass proxy") 729 } 730 return nil 731 } 732 733 frontendProxy := httptest.NewServer(rproxy) 734 defer frontendProxy.Close() 735 736 tests := []struct { 737 url string 738 wantCode int 739 }{ 740 {frontendProxy.URL + "/mod", http.StatusOK}, 741 {frontendProxy.URL + "/schedule", http.StatusBadGateway}, 742 } 743 744 for i, tt := range tests { 745 resp, err := http.Get(tt.url) 746 if err != nil { 747 t.Fatalf("failed to reach proxy: %v", err) 748 } 749 if g, e := resp.StatusCode, tt.wantCode; g != e { 750 t.Errorf("#%d: got res.StatusCode %d; expected %d", i, g, e) 751 } 752 resp.Body.Close() 753 } 754 } 755 756 type failingRoundTripper struct{} 757 758 func (failingRoundTripper) RoundTrip(*http.Request) (*http.Response, error) { 759 return nil, errors.New("some error") 760 } 761 762 type staticResponseRoundTripper struct{ res *http.Response } 763 764 func (rt staticResponseRoundTripper) RoundTrip(*http.Request) (*http.Response, error) { 765 return rt.res, nil 766 } 767 768 func TestReverseProxyErrorHandler(t *testing.T) { 769 tests := []struct { 770 name string 771 wantCode int 772 errorHandler func(http.ResponseWriter, *http.Request, error) 773 transport http.RoundTripper // defaults to failingRoundTripper 774 modifyResponse func(*http.Response) error 775 }{ 776 { 777 name: "default", 778 wantCode: http.StatusBadGateway, 779 }, 780 { 781 name: "errorhandler", 782 wantCode: http.StatusTeapot, 783 errorHandler: func(rw http.ResponseWriter, req *http.Request, err error) { rw.WriteHeader(http.StatusTeapot) }, 784 }, 785 { 786 name: "modifyresponse_noerr", 787 transport: staticResponseRoundTripper{ 788 &http.Response{StatusCode: 345, Body: http.NoBody}, 789 }, 790 modifyResponse: func(res *http.Response) error { 791 res.StatusCode++ 792 return nil 793 }, 794 errorHandler: func(rw http.ResponseWriter, req *http.Request, err error) { rw.WriteHeader(http.StatusTeapot) }, 795 wantCode: 346, 796 }, 797 { 798 name: "modifyresponse_err", 799 transport: staticResponseRoundTripper{ 800 &http.Response{StatusCode: 345, Body: http.NoBody}, 801 }, 802 modifyResponse: func(res *http.Response) error { 803 res.StatusCode++ 804 return errors.New("some error to trigger errorHandler") 805 }, 806 errorHandler: func(rw http.ResponseWriter, req *http.Request, err error) { rw.WriteHeader(http.StatusTeapot) }, 807 wantCode: http.StatusTeapot, 808 }, 809 } 810 811 for _, tt := range tests { 812 t.Run(tt.name, func(t *testing.T) { 813 target := &url.URL{ 814 Scheme: "http", 815 Host: "dummy.tld", 816 Path: "/", 817 } 818 rproxy := NewSingleHostReverseProxy(target) 819 rproxy.Transport = tt.transport 820 rproxy.ModifyResponse = tt.modifyResponse 821 if rproxy.Transport == nil { 822 rproxy.Transport = failingRoundTripper{} 823 } 824 rproxy.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests 825 if tt.errorHandler != nil { 826 rproxy.ErrorHandler = tt.errorHandler 827 } 828 frontendProxy := httptest.NewServer(rproxy) 829 defer frontendProxy.Close() 830 831 resp, err := http.Get(frontendProxy.URL + "/test") 832 if err != nil { 833 t.Fatalf("failed to reach proxy: %v", err) 834 } 835 if g, e := resp.StatusCode, tt.wantCode; g != e { 836 t.Errorf("got res.StatusCode %d; expected %d", g, e) 837 } 838 resp.Body.Close() 839 }) 840 } 841 } 842 843 // Issue 16659: log errors from short read 844 func TestReverseProxy_CopyBuffer(t *testing.T) { 845 backendServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 846 out := "this call was relayed by the reverse proxy" 847 // Coerce a wrong content length to induce io.UnexpectedEOF 848 w.Header().Set("Content-Length", fmt.Sprintf("%d", len(out)*2)) 849 fmt.Fprintln(w, out) 850 })) 851 defer backendServer.Close() 852 853 rpURL, err := url.Parse(backendServer.URL) 854 if err != nil { 855 t.Fatal(err) 856 } 857 858 var proxyLog bytes.Buffer 859 rproxy := NewSingleHostReverseProxy(rpURL) 860 rproxy.ErrorLog = log.New(&proxyLog, "", log.Lshortfile) 861 donec := make(chan bool, 1) 862 frontendProxy := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 863 defer func() { donec <- true }() 864 rproxy.ServeHTTP(w, r) 865 })) 866 defer frontendProxy.Close() 867 868 if _, err = frontendProxy.Client().Get(frontendProxy.URL); err == nil { 869 t.Fatalf("want non-nil error") 870 } 871 // The race detector complains about the proxyLog usage in logf in copyBuffer 872 // and our usage below with proxyLog.Bytes() so we're explicitly using a 873 // channel to ensure that the ReverseProxy's ServeHTTP is done before we 874 // continue after Get. 875 <-donec 876 877 expected := []string{ 878 "EOF", 879 "read", 880 } 881 for _, phrase := range expected { 882 if !bytes.Contains(proxyLog.Bytes(), []byte(phrase)) { 883 t.Errorf("expected log to contain phrase %q", phrase) 884 } 885 } 886 } 887 888 type staticTransport struct { 889 res *http.Response 890 } 891 892 func (t *staticTransport) RoundTrip(r *http.Request) (*http.Response, error) { 893 return t.res, nil 894 } 895 896 func BenchmarkServeHTTP(b *testing.B) { 897 res := &http.Response{ 898 StatusCode: 200, 899 Body: io.NopCloser(strings.NewReader("")), 900 } 901 proxy := &ReverseProxy{ 902 Director: func(*http.Request) {}, 903 Transport: &staticTransport{res}, 904 } 905 906 w := httptest.NewRecorder() 907 r := httptest.NewRequest("GET", "/", nil) 908 909 b.ReportAllocs() 910 for i := 0; i < b.N; i++ { 911 proxy.ServeHTTP(w, r) 912 } 913 } 914 915 func TestServeHTTPDeepCopy(t *testing.T) { 916 backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 917 w.Write([]byte("Hello Gopher!")) 918 })) 919 defer backend.Close() 920 backendURL, err := url.Parse(backend.URL) 921 if err != nil { 922 t.Fatal(err) 923 } 924 925 type result struct { 926 before, after string 927 } 928 929 resultChan := make(chan result, 1) 930 proxyHandler := NewSingleHostReverseProxy(backendURL) 931 frontend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 932 before := r.URL.String() 933 proxyHandler.ServeHTTP(w, r) 934 after := r.URL.String() 935 resultChan <- result{before: before, after: after} 936 })) 937 defer frontend.Close() 938 939 want := result{before: "/", after: "/"} 940 941 res, err := frontend.Client().Get(frontend.URL) 942 if err != nil { 943 t.Fatalf("Do: %v", err) 944 } 945 res.Body.Close() 946 947 got := <-resultChan 948 if got != want { 949 t.Errorf("got = %+v; want = %+v", got, want) 950 } 951 } 952 953 // Issue 18327: verify we always do a deep copy of the Request.Header map 954 // before any mutations. 955 func TestClonesRequestHeaders(t *testing.T) { 956 log.SetOutput(io.Discard) 957 defer log.SetOutput(os.Stderr) 958 req, _ := http.NewRequest("GET", "http://foo.tld/", nil) 959 req.RemoteAddr = "1.2.3.4:56789" 960 rp := &ReverseProxy{ 961 Director: func(req *http.Request) { 962 req.Header.Set("From-Director", "1") 963 }, 964 Transport: roundTripperFunc(func(req *http.Request) (*http.Response, error) { 965 if v := req.Header.Get("From-Director"); v != "1" { 966 t.Errorf("From-Directory value = %q; want 1", v) 967 } 968 return nil, io.EOF 969 }), 970 } 971 rp.ServeHTTP(httptest.NewRecorder(), req) 972 973 if req.Header.Get("From-Director") == "1" { 974 t.Error("Director header mutation modified caller's request") 975 } 976 if req.Header.Get("X-Forwarded-For") != "" { 977 t.Error("X-Forward-For header mutation modified caller's request") 978 } 979 980 } 981 982 type roundTripperFunc func(req *http.Request) (*http.Response, error) 983 984 func (fn roundTripperFunc) RoundTrip(req *http.Request) (*http.Response, error) { 985 return fn(req) 986 } 987 988 func TestModifyResponseClosesBody(t *testing.T) { 989 req, _ := http.NewRequest("GET", "http://foo.tld/", nil) 990 req.RemoteAddr = "1.2.3.4:56789" 991 closeCheck := new(checkCloser) 992 logBuf := new(bytes.Buffer) 993 outErr := errors.New("ModifyResponse error") 994 rp := &ReverseProxy{ 995 Director: func(req *http.Request) {}, 996 Transport: &staticTransport{&http.Response{ 997 StatusCode: 200, 998 Body: closeCheck, 999 }}, 1000 ErrorLog: log.New(logBuf, "", 0), 1001 ModifyResponse: func(*http.Response) error { 1002 return outErr 1003 }, 1004 } 1005 rec := httptest.NewRecorder() 1006 rp.ServeHTTP(rec, req) 1007 res := rec.Result() 1008 if g, e := res.StatusCode, http.StatusBadGateway; g != e { 1009 t.Errorf("got res.StatusCode %d; expected %d", g, e) 1010 } 1011 if !closeCheck.closed { 1012 t.Errorf("body should have been closed") 1013 } 1014 if g, e := logBuf.String(), outErr.Error(); !strings.Contains(g, e) { 1015 t.Errorf("ErrorLog %q does not contain %q", g, e) 1016 } 1017 } 1018 1019 type checkCloser struct { 1020 closed bool 1021 } 1022 1023 func (cc *checkCloser) Close() error { 1024 cc.closed = true 1025 return nil 1026 } 1027 1028 func (cc *checkCloser) Read(b []byte) (int, error) { 1029 return len(b), nil 1030 } 1031 1032 // Issue 23643: panic on body copy error 1033 func TestReverseProxy_PanicBodyError(t *testing.T) { 1034 log.SetOutput(io.Discard) 1035 defer log.SetOutput(os.Stderr) 1036 backendServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 1037 out := "this call was relayed by the reverse proxy" 1038 // Coerce a wrong content length to induce io.ErrUnexpectedEOF 1039 w.Header().Set("Content-Length", fmt.Sprintf("%d", len(out)*2)) 1040 fmt.Fprintln(w, out) 1041 })) 1042 defer backendServer.Close() 1043 1044 rpURL, err := url.Parse(backendServer.URL) 1045 if err != nil { 1046 t.Fatal(err) 1047 } 1048 1049 rproxy := NewSingleHostReverseProxy(rpURL) 1050 1051 // Ensure that the handler panics when the body read encounters an 1052 // io.ErrUnexpectedEOF 1053 defer func() { 1054 err := recover() 1055 if err == nil { 1056 t.Fatal("handler should have panicked") 1057 } 1058 if err != http.ErrAbortHandler { 1059 t.Fatal("expected ErrAbortHandler, got", err) 1060 } 1061 }() 1062 req, _ := http.NewRequest("GET", "http://foo.tld/", nil) 1063 rproxy.ServeHTTP(httptest.NewRecorder(), req) 1064 } 1065 1066 func TestSelectFlushInterval(t *testing.T) { 1067 tests := []struct { 1068 name string 1069 p *ReverseProxy 1070 res *http.Response 1071 want time.Duration 1072 }{ 1073 { 1074 name: "default", 1075 res: &http.Response{}, 1076 p: &ReverseProxy{FlushInterval: 123}, 1077 want: 123, 1078 }, 1079 { 1080 name: "server-sent events overrides non-zero", 1081 res: &http.Response{ 1082 Header: http.Header{ 1083 "Content-Type": {"text/event-stream"}, 1084 }, 1085 }, 1086 p: &ReverseProxy{FlushInterval: 123}, 1087 want: -1, 1088 }, 1089 { 1090 name: "server-sent events overrides zero", 1091 res: &http.Response{ 1092 Header: http.Header{ 1093 "Content-Type": {"text/event-stream"}, 1094 }, 1095 }, 1096 p: &ReverseProxy{FlushInterval: 0}, 1097 want: -1, 1098 }, 1099 { 1100 name: "Content-Length: -1, overrides non-zero", 1101 res: &http.Response{ 1102 ContentLength: -1, 1103 }, 1104 p: &ReverseProxy{FlushInterval: 123}, 1105 want: -1, 1106 }, 1107 { 1108 name: "Content-Length: -1, overrides zero", 1109 res: &http.Response{ 1110 ContentLength: -1, 1111 }, 1112 p: &ReverseProxy{FlushInterval: 0}, 1113 want: -1, 1114 }, 1115 } 1116 for _, tt := range tests { 1117 t.Run(tt.name, func(t *testing.T) { 1118 got := tt.p.flushInterval(tt.res) 1119 if got != tt.want { 1120 t.Errorf("flushLatency = %v; want %v", got, tt.want) 1121 } 1122 }) 1123 } 1124 } 1125 1126 func TestReverseProxyWebSocket(t *testing.T) { 1127 backendServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 1128 if upgradeType(r.Header) != "websocket" { 1129 t.Error("unexpected backend request") 1130 http.Error(w, "unexpected request", 400) 1131 return 1132 } 1133 c, _, err := w.(http.Hijacker).Hijack() 1134 if err != nil { 1135 t.Error(err) 1136 return 1137 } 1138 defer c.Close() 1139 io.WriteString(c, "HTTP/1.1 101 Switching Protocols\r\nConnection: upgrade\r\nUpgrade: WebSocket\r\n\r\n") 1140 bs := bufio.NewScanner(c) 1141 if !bs.Scan() { 1142 t.Errorf("backend failed to read line from client: %v", bs.Err()) 1143 return 1144 } 1145 fmt.Fprintf(c, "backend got %q\n", bs.Text()) 1146 })) 1147 defer backendServer.Close() 1148 1149 backURL, _ := url.Parse(backendServer.URL) 1150 rproxy := NewSingleHostReverseProxy(backURL) 1151 rproxy.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests 1152 rproxy.ModifyResponse = func(res *http.Response) error { 1153 res.Header.Add("X-Modified", "true") 1154 return nil 1155 } 1156 1157 handler := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { 1158 rw.Header().Set("X-Header", "X-Value") 1159 rproxy.ServeHTTP(rw, req) 1160 if got, want := rw.Header().Get("X-Modified"), "true"; got != want { 1161 t.Errorf("response writer X-Modified header = %q; want %q", got, want) 1162 } 1163 }) 1164 1165 frontendProxy := httptest.NewServer(handler) 1166 defer frontendProxy.Close() 1167 1168 req, _ := http.NewRequest("GET", frontendProxy.URL, nil) 1169 req.Header.Set("Connection", "Upgrade") 1170 req.Header.Set("Upgrade", "websocket") 1171 1172 c := frontendProxy.Client() 1173 res, err := c.Do(req) 1174 if err != nil { 1175 t.Fatal(err) 1176 } 1177 if res.StatusCode != 101 { 1178 t.Fatalf("status = %v; want 101", res.Status) 1179 } 1180 1181 got := res.Header.Get("X-Header") 1182 want := "X-Value" 1183 if got != want { 1184 t.Errorf("Header(XHeader) = %q; want %q", got, want) 1185 } 1186 1187 if upgradeType(res.Header) != "websocket" { 1188 t.Fatalf("not websocket upgrade; got %#v", res.Header) 1189 } 1190 rwc, ok := res.Body.(io.ReadWriteCloser) 1191 if !ok { 1192 t.Fatalf("response body is of type %T; does not implement ReadWriteCloser", res.Body) 1193 } 1194 defer rwc.Close() 1195 1196 if got, want := res.Header.Get("X-Modified"), "true"; got != want { 1197 t.Errorf("response X-Modified header = %q; want %q", got, want) 1198 } 1199 1200 io.WriteString(rwc, "Hello\n") 1201 bs := bufio.NewScanner(rwc) 1202 if !bs.Scan() { 1203 t.Fatalf("Scan: %v", bs.Err()) 1204 } 1205 got = bs.Text() 1206 want = `backend got "Hello"` 1207 if got != want { 1208 t.Errorf("got %#q, want %#q", got, want) 1209 } 1210 } 1211 1212 func TestReverseProxyWebSocketCancelation(t *testing.T) { 1213 n := 5 1214 triggerCancelCh := make(chan bool, n) 1215 nthResponse := func(i int) string { 1216 return fmt.Sprintf("backend response #%d\n", i) 1217 } 1218 terminalMsg := "final message" 1219 1220 cst := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 1221 if g, ws := upgradeType(r.Header), "websocket"; g != ws { 1222 t.Errorf("Unexpected upgrade type %q, want %q", g, ws) 1223 http.Error(w, "Unexpected request", 400) 1224 return 1225 } 1226 conn, bufrw, err := w.(http.Hijacker).Hijack() 1227 if err != nil { 1228 t.Error(err) 1229 return 1230 } 1231 defer conn.Close() 1232 1233 upgradeMsg := "HTTP/1.1 101 Switching Protocols\r\nConnection: upgrade\r\nUpgrade: WebSocket\r\n\r\n" 1234 if _, err := io.WriteString(conn, upgradeMsg); err != nil { 1235 t.Error(err) 1236 return 1237 } 1238 if _, _, err := bufrw.ReadLine(); err != nil { 1239 t.Errorf("Failed to read line from client: %v", err) 1240 return 1241 } 1242 1243 for i := 0; i < n; i++ { 1244 if _, err := bufrw.WriteString(nthResponse(i)); err != nil { 1245 select { 1246 case <-triggerCancelCh: 1247 default: 1248 t.Errorf("Writing response #%d failed: %v", i, err) 1249 } 1250 return 1251 } 1252 bufrw.Flush() 1253 time.Sleep(time.Second) 1254 } 1255 if _, err := bufrw.WriteString(terminalMsg); err != nil { 1256 select { 1257 case <-triggerCancelCh: 1258 default: 1259 t.Errorf("Failed to write terminal message: %v", err) 1260 } 1261 } 1262 bufrw.Flush() 1263 })) 1264 defer cst.Close() 1265 1266 backendURL, _ := url.Parse(cst.URL) 1267 rproxy := NewSingleHostReverseProxy(backendURL) 1268 rproxy.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests 1269 rproxy.ModifyResponse = func(res *http.Response) error { 1270 res.Header.Add("X-Modified", "true") 1271 return nil 1272 } 1273 1274 handler := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { 1275 rw.Header().Set("X-Header", "X-Value") 1276 ctx, cancel := context.WithCancel(req.Context()) 1277 go func() { 1278 <-triggerCancelCh 1279 cancel() 1280 }() 1281 rproxy.ServeHTTP(rw, req.WithContext(ctx)) 1282 }) 1283 1284 frontendProxy := httptest.NewServer(handler) 1285 defer frontendProxy.Close() 1286 1287 req, _ := http.NewRequest("GET", frontendProxy.URL, nil) 1288 req.Header.Set("Connection", "Upgrade") 1289 req.Header.Set("Upgrade", "websocket") 1290 1291 res, err := frontendProxy.Client().Do(req) 1292 if err != nil { 1293 t.Fatalf("Dialing to frontend proxy: %v", err) 1294 } 1295 defer res.Body.Close() 1296 if g, w := res.StatusCode, 101; g != w { 1297 t.Fatalf("Switching protocols failed, got: %d, want: %d", g, w) 1298 } 1299 1300 if g, w := res.Header.Get("X-Header"), "X-Value"; g != w { 1301 t.Errorf("X-Header mismatch\n\tgot: %q\n\twant: %q", g, w) 1302 } 1303 1304 if g, w := upgradeType(res.Header), "websocket"; g != w { 1305 t.Fatalf("Upgrade header mismatch\n\tgot: %q\n\twant: %q", g, w) 1306 } 1307 1308 rwc, ok := res.Body.(io.ReadWriteCloser) 1309 if !ok { 1310 t.Fatalf("Response body type mismatch, got %T, want io.ReadWriteCloser", res.Body) 1311 } 1312 1313 if got, want := res.Header.Get("X-Modified"), "true"; got != want { 1314 t.Errorf("response X-Modified header = %q; want %q", got, want) 1315 } 1316 1317 if _, err := io.WriteString(rwc, "Hello\n"); err != nil { 1318 t.Fatalf("Failed to write first message: %v", err) 1319 } 1320 1321 // Read loop. 1322 1323 br := bufio.NewReader(rwc) 1324 for { 1325 line, err := br.ReadString('\n') 1326 switch { 1327 case line == terminalMsg: // this case before "err == io.EOF" 1328 t.Fatalf("The websocket request was not canceled, unfortunately!") 1329 1330 case err == io.EOF: 1331 return 1332 1333 case err != nil: 1334 t.Fatalf("Unexpected error: %v", err) 1335 1336 case line == nthResponse(0): // We've gotten the first response back 1337 // Let's trigger a cancel. 1338 close(triggerCancelCh) 1339 } 1340 } 1341 } 1342 1343 func TestUnannouncedTrailer(t *testing.T) { 1344 backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 1345 w.WriteHeader(http.StatusOK) 1346 w.(http.Flusher).Flush() 1347 w.Header().Set(http.TrailerPrefix+"X-Unannounced-Trailer", "unannounced_trailer_value") 1348 })) 1349 defer backend.Close() 1350 backendURL, err := url.Parse(backend.URL) 1351 if err != nil { 1352 t.Fatal(err) 1353 } 1354 proxyHandler := NewSingleHostReverseProxy(backendURL) 1355 proxyHandler.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests 1356 frontend := httptest.NewServer(proxyHandler) 1357 defer frontend.Close() 1358 frontendClient := frontend.Client() 1359 1360 res, err := frontendClient.Get(frontend.URL) 1361 if err != nil { 1362 t.Fatalf("Get: %v", err) 1363 } 1364 1365 io.ReadAll(res.Body) 1366 1367 if g, w := res.Trailer.Get("X-Unannounced-Trailer"), "unannounced_trailer_value"; g != w { 1368 t.Errorf("Trailer(X-Unannounced-Trailer) = %q; want %q", g, w) 1369 } 1370 1371 } 1372 1373 func TestSingleJoinSlash(t *testing.T) { 1374 tests := []struct { 1375 slasha string 1376 slashb string 1377 expected string 1378 }{ 1379 {"https://www.google.com/", "/favicon.ico", "https://www.google.com/favicon.ico"}, 1380 {"https://www.google.com", "/favicon.ico", "https://www.google.com/favicon.ico"}, 1381 {"https://www.google.com", "favicon.ico", "https://www.google.com/favicon.ico"}, 1382 {"https://www.google.com", "", "https://www.google.com/"}, 1383 {"", "favicon.ico", "/favicon.ico"}, 1384 } 1385 for _, tt := range tests { 1386 if got := singleJoiningSlash(tt.slasha, tt.slashb); got != tt.expected { 1387 t.Errorf("singleJoiningSlash(%q,%q) want %q got %q", 1388 tt.slasha, 1389 tt.slashb, 1390 tt.expected, 1391 got) 1392 } 1393 } 1394 } 1395 1396 func TestJoinURLPath(t *testing.T) { 1397 tests := []struct { 1398 a *url.URL 1399 b *url.URL 1400 wantPath string 1401 wantRaw string 1402 }{ 1403 {&url.URL{Path: "/a/b"}, &url.URL{Path: "/c"}, "/a/b/c", ""}, 1404 {&url.URL{Path: "/a/b", RawPath: "badpath"}, &url.URL{Path: "c"}, "/a/b/c", "/a/b/c"}, 1405 {&url.URL{Path: "/a/b", RawPath: "/a%2Fb"}, &url.URL{Path: "/c"}, "/a/b/c", "/a%2Fb/c"}, 1406 {&url.URL{Path: "/a/b", RawPath: "/a%2Fb"}, &url.URL{Path: "/c"}, "/a/b/c", "/a%2Fb/c"}, 1407 {&url.URL{Path: "/a/b/", RawPath: "/a%2Fb%2F"}, &url.URL{Path: "c"}, "/a/b//c", "/a%2Fb%2F/c"}, 1408 {&url.URL{Path: "/a/b/", RawPath: "/a%2Fb/"}, &url.URL{Path: "/c/d", RawPath: "/c%2Fd"}, "/a/b/c/d", "/a%2Fb/c%2Fd"}, 1409 } 1410 1411 for _, tt := range tests { 1412 p, rp := joinURLPath(tt.a, tt.b) 1413 if p != tt.wantPath || rp != tt.wantRaw { 1414 t.Errorf("joinURLPath(URL(%q,%q),URL(%q,%q)) want (%q,%q) got (%q,%q)", 1415 tt.a.Path, tt.a.RawPath, 1416 tt.b.Path, tt.b.RawPath, 1417 tt.wantPath, tt.wantRaw, 1418 p, rp) 1419 } 1420 } 1421 }