github.com/geraldss/go/src@v0.0.0-20210511222824-ac7d0ebfc235/net/http/httputil/reverseproxy_test.go (about) 1 // Copyright 2011 The Go Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 // Reverse proxy tests. 6 7 package httputil 8 9 import ( 10 "bufio" 11 "bytes" 12 "context" 13 "errors" 14 "fmt" 15 "io" 16 "log" 17 "net/http" 18 "net/http/httptest" 19 "net/url" 20 "os" 21 "reflect" 22 "sort" 23 "strconv" 24 "strings" 25 "sync" 26 "testing" 27 "time" 28 ) 29 30 const fakeHopHeader = "X-Fake-Hop-Header-For-Test" 31 32 func init() { 33 inOurTests = true 34 hopHeaders = append(hopHeaders, fakeHopHeader) 35 } 36 37 func TestReverseProxy(t *testing.T) { 38 const backendResponse = "I am the backend" 39 const backendStatus = 404 40 backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 41 if r.Method == "GET" && r.FormValue("mode") == "hangup" { 42 c, _, _ := w.(http.Hijacker).Hijack() 43 c.Close() 44 return 45 } 46 if len(r.TransferEncoding) > 0 { 47 t.Errorf("backend got unexpected TransferEncoding: %v", r.TransferEncoding) 48 } 49 if r.Header.Get("X-Forwarded-For") == "" { 50 t.Errorf("didn't get X-Forwarded-For header") 51 } 52 if c := r.Header.Get("Connection"); c != "" { 53 t.Errorf("handler got Connection header value %q", c) 54 } 55 if c := r.Header.Get("Te"); c != "trailers" { 56 t.Errorf("handler got Te header value %q; want 'trailers'", c) 57 } 58 if c := r.Header.Get("Upgrade"); c != "" { 59 t.Errorf("handler got Upgrade header value %q", c) 60 } 61 if c := r.Header.Get("Proxy-Connection"); c != "" { 62 t.Errorf("handler got Proxy-Connection header value %q", c) 63 } 64 if g, e := r.Host, "some-name"; g != e { 65 t.Errorf("backend got Host header %q, want %q", g, e) 66 } 67 w.Header().Set("Trailers", "not a special header field name") 68 w.Header().Set("Trailer", "X-Trailer") 69 w.Header().Set("X-Foo", "bar") 70 w.Header().Set("Upgrade", "foo") 71 w.Header().Set(fakeHopHeader, "foo") 72 w.Header().Add("X-Multi-Value", "foo") 73 w.Header().Add("X-Multi-Value", "bar") 74 http.SetCookie(w, &http.Cookie{Name: "flavor", Value: "chocolateChip"}) 75 w.WriteHeader(backendStatus) 76 w.Write([]byte(backendResponse)) 77 w.Header().Set("X-Trailer", "trailer_value") 78 w.Header().Set(http.TrailerPrefix+"X-Unannounced-Trailer", "unannounced_trailer_value") 79 })) 80 defer backend.Close() 81 backendURL, err := url.Parse(backend.URL) 82 if err != nil { 83 t.Fatal(err) 84 } 85 proxyHandler := NewSingleHostReverseProxy(backendURL) 86 proxyHandler.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests 87 frontend := httptest.NewServer(proxyHandler) 88 defer frontend.Close() 89 frontendClient := frontend.Client() 90 91 getReq, _ := http.NewRequest("GET", frontend.URL, nil) 92 getReq.Host = "some-name" 93 getReq.Header.Set("Connection", "close") 94 getReq.Header.Set("Te", "trailers") 95 getReq.Header.Set("Proxy-Connection", "should be deleted") 96 getReq.Header.Set("Upgrade", "foo") 97 getReq.Close = true 98 res, err := frontendClient.Do(getReq) 99 if err != nil { 100 t.Fatalf("Get: %v", err) 101 } 102 if g, e := res.StatusCode, backendStatus; g != e { 103 t.Errorf("got res.StatusCode %d; expected %d", g, e) 104 } 105 if g, e := res.Header.Get("X-Foo"), "bar"; g != e { 106 t.Errorf("got X-Foo %q; expected %q", g, e) 107 } 108 if c := res.Header.Get(fakeHopHeader); c != "" { 109 t.Errorf("got %s header value %q", fakeHopHeader, c) 110 } 111 if g, e := res.Header.Get("Trailers"), "not a special header field name"; g != e { 112 t.Errorf("header Trailers = %q; want %q", g, e) 113 } 114 if g, e := len(res.Header["X-Multi-Value"]), 2; g != e { 115 t.Errorf("got %d X-Multi-Value header values; expected %d", g, e) 116 } 117 if g, e := len(res.Header["Set-Cookie"]), 1; g != e { 118 t.Fatalf("got %d SetCookies, want %d", g, e) 119 } 120 if g, e := res.Trailer, (http.Header{"X-Trailer": nil}); !reflect.DeepEqual(g, e) { 121 t.Errorf("before reading body, Trailer = %#v; want %#v", g, e) 122 } 123 if cookie := res.Cookies()[0]; cookie.Name != "flavor" { 124 t.Errorf("unexpected cookie %q", cookie.Name) 125 } 126 bodyBytes, _ := io.ReadAll(res.Body) 127 if g, e := string(bodyBytes), backendResponse; g != e { 128 t.Errorf("got body %q; expected %q", g, e) 129 } 130 if g, e := res.Trailer.Get("X-Trailer"), "trailer_value"; g != e { 131 t.Errorf("Trailer(X-Trailer) = %q ; want %q", g, e) 132 } 133 if g, e := res.Trailer.Get("X-Unannounced-Trailer"), "unannounced_trailer_value"; g != e { 134 t.Errorf("Trailer(X-Unannounced-Trailer) = %q ; want %q", g, e) 135 } 136 137 // Test that a backend failing to be reached or one which doesn't return 138 // a response results in a StatusBadGateway. 139 getReq, _ = http.NewRequest("GET", frontend.URL+"/?mode=hangup", nil) 140 getReq.Close = true 141 res, err = frontendClient.Do(getReq) 142 if err != nil { 143 t.Fatal(err) 144 } 145 res.Body.Close() 146 if res.StatusCode != http.StatusBadGateway { 147 t.Errorf("request to bad proxy = %v; want 502 StatusBadGateway", res.Status) 148 } 149 150 } 151 152 // Issue 16875: remove any proxied headers mentioned in the "Connection" 153 // header value. 154 func TestReverseProxyStripHeadersPresentInConnection(t *testing.T) { 155 const fakeConnectionToken = "X-Fake-Connection-Token" 156 const backendResponse = "I am the backend" 157 158 // someConnHeader is some arbitrary header to be declared as a hop-by-hop header 159 // in the Request's Connection header. 160 const someConnHeader = "X-Some-Conn-Header" 161 162 backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 163 if c := r.Header.Get("Connection"); c != "" { 164 t.Errorf("handler got header %q = %q; want empty", "Connection", c) 165 } 166 if c := r.Header.Get(fakeConnectionToken); c != "" { 167 t.Errorf("handler got header %q = %q; want empty", fakeConnectionToken, c) 168 } 169 if c := r.Header.Get(someConnHeader); c != "" { 170 t.Errorf("handler got header %q = %q; want empty", someConnHeader, c) 171 } 172 w.Header().Add("Connection", "Upgrade, "+fakeConnectionToken) 173 w.Header().Add("Connection", someConnHeader) 174 w.Header().Set(someConnHeader, "should be deleted") 175 w.Header().Set(fakeConnectionToken, "should be deleted") 176 io.WriteString(w, backendResponse) 177 })) 178 defer backend.Close() 179 backendURL, err := url.Parse(backend.URL) 180 if err != nil { 181 t.Fatal(err) 182 } 183 proxyHandler := NewSingleHostReverseProxy(backendURL) 184 frontend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 185 proxyHandler.ServeHTTP(w, r) 186 if c := r.Header.Get(someConnHeader); c != "should be deleted" { 187 t.Errorf("handler modified header %q = %q; want %q", someConnHeader, c, "should be deleted") 188 } 189 if c := r.Header.Get(fakeConnectionToken); c != "should be deleted" { 190 t.Errorf("handler modified header %q = %q; want %q", fakeConnectionToken, c, "should be deleted") 191 } 192 c := r.Header["Connection"] 193 var cf []string 194 for _, f := range c { 195 for _, sf := range strings.Split(f, ",") { 196 if sf = strings.TrimSpace(sf); sf != "" { 197 cf = append(cf, sf) 198 } 199 } 200 } 201 sort.Strings(cf) 202 expectedValues := []string{"Upgrade", someConnHeader, fakeConnectionToken} 203 sort.Strings(expectedValues) 204 if !reflect.DeepEqual(cf, expectedValues) { 205 t.Errorf("handler modified header %q = %q; want %q", "Connection", cf, expectedValues) 206 } 207 })) 208 defer frontend.Close() 209 210 getReq, _ := http.NewRequest("GET", frontend.URL, nil) 211 getReq.Header.Add("Connection", "Upgrade, "+fakeConnectionToken) 212 getReq.Header.Add("Connection", someConnHeader) 213 getReq.Header.Set(someConnHeader, "should be deleted") 214 getReq.Header.Set(fakeConnectionToken, "should be deleted") 215 res, err := frontend.Client().Do(getReq) 216 if err != nil { 217 t.Fatalf("Get: %v", err) 218 } 219 defer res.Body.Close() 220 bodyBytes, err := io.ReadAll(res.Body) 221 if err != nil { 222 t.Fatalf("reading body: %v", err) 223 } 224 if got, want := string(bodyBytes), backendResponse; got != want { 225 t.Errorf("got body %q; want %q", got, want) 226 } 227 if c := res.Header.Get("Connection"); c != "" { 228 t.Errorf("handler got header %q = %q; want empty", "Connection", c) 229 } 230 if c := res.Header.Get(someConnHeader); c != "" { 231 t.Errorf("handler got header %q = %q; want empty", someConnHeader, c) 232 } 233 if c := res.Header.Get(fakeConnectionToken); c != "" { 234 t.Errorf("handler got header %q = %q; want empty", fakeConnectionToken, c) 235 } 236 } 237 238 func TestXForwardedFor(t *testing.T) { 239 const prevForwardedFor = "client ip" 240 const backendResponse = "I am the backend" 241 const backendStatus = 404 242 backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 243 if r.Header.Get("X-Forwarded-For") == "" { 244 t.Errorf("didn't get X-Forwarded-For header") 245 } 246 if !strings.Contains(r.Header.Get("X-Forwarded-For"), prevForwardedFor) { 247 t.Errorf("X-Forwarded-For didn't contain prior data") 248 } 249 w.WriteHeader(backendStatus) 250 w.Write([]byte(backendResponse)) 251 })) 252 defer backend.Close() 253 backendURL, err := url.Parse(backend.URL) 254 if err != nil { 255 t.Fatal(err) 256 } 257 proxyHandler := NewSingleHostReverseProxy(backendURL) 258 frontend := httptest.NewServer(proxyHandler) 259 defer frontend.Close() 260 261 getReq, _ := http.NewRequest("GET", frontend.URL, nil) 262 getReq.Host = "some-name" 263 getReq.Header.Set("Connection", "close") 264 getReq.Header.Set("X-Forwarded-For", prevForwardedFor) 265 getReq.Close = true 266 res, err := frontend.Client().Do(getReq) 267 if err != nil { 268 t.Fatalf("Get: %v", err) 269 } 270 if g, e := res.StatusCode, backendStatus; g != e { 271 t.Errorf("got res.StatusCode %d; expected %d", g, e) 272 } 273 bodyBytes, _ := io.ReadAll(res.Body) 274 if g, e := string(bodyBytes), backendResponse; g != e { 275 t.Errorf("got body %q; expected %q", g, e) 276 } 277 } 278 279 // Issue 38079: don't append to X-Forwarded-For if it's present but nil 280 func TestXForwardedFor_Omit(t *testing.T) { 281 backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 282 if v := r.Header.Get("X-Forwarded-For"); v != "" { 283 t.Errorf("got X-Forwarded-For header: %q", v) 284 } 285 w.Write([]byte("hi")) 286 })) 287 defer backend.Close() 288 backendURL, err := url.Parse(backend.URL) 289 if err != nil { 290 t.Fatal(err) 291 } 292 proxyHandler := NewSingleHostReverseProxy(backendURL) 293 frontend := httptest.NewServer(proxyHandler) 294 defer frontend.Close() 295 296 oldDirector := proxyHandler.Director 297 proxyHandler.Director = func(r *http.Request) { 298 r.Header["X-Forwarded-For"] = nil 299 oldDirector(r) 300 } 301 302 getReq, _ := http.NewRequest("GET", frontend.URL, nil) 303 getReq.Host = "some-name" 304 getReq.Close = true 305 res, err := frontend.Client().Do(getReq) 306 if err != nil { 307 t.Fatalf("Get: %v", err) 308 } 309 res.Body.Close() 310 } 311 312 var proxyQueryTests = []struct { 313 baseSuffix string // suffix to add to backend URL 314 reqSuffix string // suffix to add to frontend's request URL 315 want string // what backend should see for final request URL (without ?) 316 }{ 317 {"", "", ""}, 318 {"?sta=tic", "?us=er", "sta=tic&us=er"}, 319 {"", "?us=er", "us=er"}, 320 {"?sta=tic", "", "sta=tic"}, 321 } 322 323 func TestReverseProxyQuery(t *testing.T) { 324 backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 325 w.Header().Set("X-Got-Query", r.URL.RawQuery) 326 w.Write([]byte("hi")) 327 })) 328 defer backend.Close() 329 330 for i, tt := range proxyQueryTests { 331 backendURL, err := url.Parse(backend.URL + tt.baseSuffix) 332 if err != nil { 333 t.Fatal(err) 334 } 335 frontend := httptest.NewServer(NewSingleHostReverseProxy(backendURL)) 336 req, _ := http.NewRequest("GET", frontend.URL+tt.reqSuffix, nil) 337 req.Close = true 338 res, err := frontend.Client().Do(req) 339 if err != nil { 340 t.Fatalf("%d. Get: %v", i, err) 341 } 342 if g, e := res.Header.Get("X-Got-Query"), tt.want; g != e { 343 t.Errorf("%d. got query %q; expected %q", i, g, e) 344 } 345 res.Body.Close() 346 frontend.Close() 347 } 348 } 349 350 func TestReverseProxyFlushInterval(t *testing.T) { 351 const expected = "hi" 352 backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 353 w.Write([]byte(expected)) 354 })) 355 defer backend.Close() 356 357 backendURL, err := url.Parse(backend.URL) 358 if err != nil { 359 t.Fatal(err) 360 } 361 362 proxyHandler := NewSingleHostReverseProxy(backendURL) 363 proxyHandler.FlushInterval = time.Microsecond 364 365 frontend := httptest.NewServer(proxyHandler) 366 defer frontend.Close() 367 368 req, _ := http.NewRequest("GET", frontend.URL, nil) 369 req.Close = true 370 res, err := frontend.Client().Do(req) 371 if err != nil { 372 t.Fatalf("Get: %v", err) 373 } 374 defer res.Body.Close() 375 if bodyBytes, _ := io.ReadAll(res.Body); string(bodyBytes) != expected { 376 t.Errorf("got body %q; expected %q", bodyBytes, expected) 377 } 378 } 379 380 func TestReverseProxyFlushIntervalHeaders(t *testing.T) { 381 const expected = "hi" 382 stopCh := make(chan struct{}) 383 backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 384 w.Header().Add("MyHeader", expected) 385 w.WriteHeader(200) 386 w.(http.Flusher).Flush() 387 <-stopCh 388 })) 389 defer backend.Close() 390 defer close(stopCh) 391 392 backendURL, err := url.Parse(backend.URL) 393 if err != nil { 394 t.Fatal(err) 395 } 396 397 proxyHandler := NewSingleHostReverseProxy(backendURL) 398 proxyHandler.FlushInterval = time.Microsecond 399 400 frontend := httptest.NewServer(proxyHandler) 401 defer frontend.Close() 402 403 req, _ := http.NewRequest("GET", frontend.URL, nil) 404 req.Close = true 405 406 ctx, cancel := context.WithTimeout(req.Context(), 10*time.Second) 407 defer cancel() 408 req = req.WithContext(ctx) 409 410 res, err := frontend.Client().Do(req) 411 if err != nil { 412 t.Fatalf("Get: %v", err) 413 } 414 defer res.Body.Close() 415 416 if res.Header.Get("MyHeader") != expected { 417 t.Errorf("got header %q; expected %q", res.Header.Get("MyHeader"), expected) 418 } 419 } 420 421 func TestReverseProxyCancellation(t *testing.T) { 422 const backendResponse = "I am the backend" 423 424 reqInFlight := make(chan struct{}) 425 backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 426 close(reqInFlight) // cause the client to cancel its request 427 428 select { 429 case <-time.After(10 * time.Second): 430 // Note: this should only happen in broken implementations, and the 431 // closenotify case should be instantaneous. 432 t.Error("Handler never saw CloseNotify") 433 return 434 case <-w.(http.CloseNotifier).CloseNotify(): 435 } 436 437 w.WriteHeader(http.StatusOK) 438 w.Write([]byte(backendResponse)) 439 })) 440 441 defer backend.Close() 442 443 backend.Config.ErrorLog = log.New(io.Discard, "", 0) 444 445 backendURL, err := url.Parse(backend.URL) 446 if err != nil { 447 t.Fatal(err) 448 } 449 450 proxyHandler := NewSingleHostReverseProxy(backendURL) 451 452 // Discards errors of the form: 453 // http: proxy error: read tcp 127.0.0.1:44643: use of closed network connection 454 proxyHandler.ErrorLog = log.New(io.Discard, "", 0) 455 456 frontend := httptest.NewServer(proxyHandler) 457 defer frontend.Close() 458 frontendClient := frontend.Client() 459 460 getReq, _ := http.NewRequest("GET", frontend.URL, nil) 461 go func() { 462 <-reqInFlight 463 frontendClient.Transport.(*http.Transport).CancelRequest(getReq) 464 }() 465 res, err := frontendClient.Do(getReq) 466 if res != nil { 467 t.Errorf("got response %v; want nil", res.Status) 468 } 469 if err == nil { 470 // This should be an error like: 471 // Get "http://127.0.0.1:58079": read tcp 127.0.0.1:58079: 472 // use of closed network connection 473 t.Error("Server.Client().Do() returned nil error; want non-nil error") 474 } 475 } 476 477 func req(t *testing.T, v string) *http.Request { 478 req, err := http.ReadRequest(bufio.NewReader(strings.NewReader(v))) 479 if err != nil { 480 t.Fatal(err) 481 } 482 return req 483 } 484 485 // Issue 12344 486 func TestNilBody(t *testing.T) { 487 backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 488 w.Write([]byte("hi")) 489 })) 490 defer backend.Close() 491 492 frontend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { 493 backURL, _ := url.Parse(backend.URL) 494 rp := NewSingleHostReverseProxy(backURL) 495 r := req(t, "GET / HTTP/1.0\r\n\r\n") 496 r.Body = nil // this accidentally worked in Go 1.4 and below, so keep it working 497 rp.ServeHTTP(w, r) 498 })) 499 defer frontend.Close() 500 501 res, err := http.Get(frontend.URL) 502 if err != nil { 503 t.Fatal(err) 504 } 505 defer res.Body.Close() 506 slurp, err := io.ReadAll(res.Body) 507 if err != nil { 508 t.Fatal(err) 509 } 510 if string(slurp) != "hi" { 511 t.Errorf("Got %q; want %q", slurp, "hi") 512 } 513 } 514 515 // Issue 15524 516 func TestUserAgentHeader(t *testing.T) { 517 const explicitUA = "explicit UA" 518 backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 519 if r.URL.Path == "/noua" { 520 if c := r.Header.Get("User-Agent"); c != "" { 521 t.Errorf("handler got non-empty User-Agent header %q", c) 522 } 523 return 524 } 525 if c := r.Header.Get("User-Agent"); c != explicitUA { 526 t.Errorf("handler got unexpected User-Agent header %q", c) 527 } 528 })) 529 defer backend.Close() 530 backendURL, err := url.Parse(backend.URL) 531 if err != nil { 532 t.Fatal(err) 533 } 534 proxyHandler := NewSingleHostReverseProxy(backendURL) 535 proxyHandler.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests 536 frontend := httptest.NewServer(proxyHandler) 537 defer frontend.Close() 538 frontendClient := frontend.Client() 539 540 getReq, _ := http.NewRequest("GET", frontend.URL, nil) 541 getReq.Header.Set("User-Agent", explicitUA) 542 getReq.Close = true 543 res, err := frontendClient.Do(getReq) 544 if err != nil { 545 t.Fatalf("Get: %v", err) 546 } 547 res.Body.Close() 548 549 getReq, _ = http.NewRequest("GET", frontend.URL+"/noua", nil) 550 getReq.Header.Set("User-Agent", "") 551 getReq.Close = true 552 res, err = frontendClient.Do(getReq) 553 if err != nil { 554 t.Fatalf("Get: %v", err) 555 } 556 res.Body.Close() 557 } 558 559 type bufferPool struct { 560 get func() []byte 561 put func([]byte) 562 } 563 564 func (bp bufferPool) Get() []byte { return bp.get() } 565 func (bp bufferPool) Put(v []byte) { bp.put(v) } 566 567 func TestReverseProxyGetPutBuffer(t *testing.T) { 568 const msg = "hi" 569 backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 570 io.WriteString(w, msg) 571 })) 572 defer backend.Close() 573 574 backendURL, err := url.Parse(backend.URL) 575 if err != nil { 576 t.Fatal(err) 577 } 578 579 var ( 580 mu sync.Mutex 581 log []string 582 ) 583 addLog := func(event string) { 584 mu.Lock() 585 defer mu.Unlock() 586 log = append(log, event) 587 } 588 rp := NewSingleHostReverseProxy(backendURL) 589 const size = 1234 590 rp.BufferPool = bufferPool{ 591 get: func() []byte { 592 addLog("getBuf") 593 return make([]byte, size) 594 }, 595 put: func(p []byte) { 596 addLog("putBuf-" + strconv.Itoa(len(p))) 597 }, 598 } 599 frontend := httptest.NewServer(rp) 600 defer frontend.Close() 601 602 req, _ := http.NewRequest("GET", frontend.URL, nil) 603 req.Close = true 604 res, err := frontend.Client().Do(req) 605 if err != nil { 606 t.Fatalf("Get: %v", err) 607 } 608 slurp, err := io.ReadAll(res.Body) 609 res.Body.Close() 610 if err != nil { 611 t.Fatalf("reading body: %v", err) 612 } 613 if string(slurp) != msg { 614 t.Errorf("msg = %q; want %q", slurp, msg) 615 } 616 wantLog := []string{"getBuf", "putBuf-" + strconv.Itoa(size)} 617 mu.Lock() 618 defer mu.Unlock() 619 if !reflect.DeepEqual(log, wantLog) { 620 t.Errorf("Log events = %q; want %q", log, wantLog) 621 } 622 } 623 624 func TestReverseProxy_Post(t *testing.T) { 625 const backendResponse = "I am the backend" 626 const backendStatus = 200 627 var requestBody = bytes.Repeat([]byte("a"), 1<<20) 628 backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 629 slurp, err := io.ReadAll(r.Body) 630 if err != nil { 631 t.Errorf("Backend body read = %v", err) 632 } 633 if len(slurp) != len(requestBody) { 634 t.Errorf("Backend read %d request body bytes; want %d", len(slurp), len(requestBody)) 635 } 636 if !bytes.Equal(slurp, requestBody) { 637 t.Error("Backend read wrong request body.") // 1MB; omitting details 638 } 639 w.Write([]byte(backendResponse)) 640 })) 641 defer backend.Close() 642 backendURL, err := url.Parse(backend.URL) 643 if err != nil { 644 t.Fatal(err) 645 } 646 proxyHandler := NewSingleHostReverseProxy(backendURL) 647 frontend := httptest.NewServer(proxyHandler) 648 defer frontend.Close() 649 650 postReq, _ := http.NewRequest("POST", frontend.URL, bytes.NewReader(requestBody)) 651 res, err := frontend.Client().Do(postReq) 652 if err != nil { 653 t.Fatalf("Do: %v", err) 654 } 655 if g, e := res.StatusCode, backendStatus; g != e { 656 t.Errorf("got res.StatusCode %d; expected %d", g, e) 657 } 658 bodyBytes, _ := io.ReadAll(res.Body) 659 if g, e := string(bodyBytes), backendResponse; g != e { 660 t.Errorf("got body %q; expected %q", g, e) 661 } 662 } 663 664 type RoundTripperFunc func(*http.Request) (*http.Response, error) 665 666 func (fn RoundTripperFunc) RoundTrip(req *http.Request) (*http.Response, error) { 667 return fn(req) 668 } 669 670 // Issue 16036: send a Request with a nil Body when possible 671 func TestReverseProxy_NilBody(t *testing.T) { 672 backendURL, _ := url.Parse("http://fake.tld/") 673 proxyHandler := NewSingleHostReverseProxy(backendURL) 674 proxyHandler.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests 675 proxyHandler.Transport = RoundTripperFunc(func(req *http.Request) (*http.Response, error) { 676 if req.Body != nil { 677 t.Error("Body != nil; want a nil Body") 678 } 679 return nil, errors.New("done testing the interesting part; so force a 502 Gateway error") 680 }) 681 frontend := httptest.NewServer(proxyHandler) 682 defer frontend.Close() 683 684 res, err := frontend.Client().Get(frontend.URL) 685 if err != nil { 686 t.Fatal(err) 687 } 688 defer res.Body.Close() 689 if res.StatusCode != 502 { 690 t.Errorf("status code = %v; want 502 (Gateway Error)", res.Status) 691 } 692 } 693 694 // Issue 33142: always allocate the request headers 695 func TestReverseProxy_AllocatedHeader(t *testing.T) { 696 proxyHandler := new(ReverseProxy) 697 proxyHandler.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests 698 proxyHandler.Director = func(*http.Request) {} // noop 699 proxyHandler.Transport = RoundTripperFunc(func(req *http.Request) (*http.Response, error) { 700 if req.Header == nil { 701 t.Error("Header == nil; want a non-nil Header") 702 } 703 return nil, errors.New("done testing the interesting part; so force a 502 Gateway error") 704 }) 705 706 proxyHandler.ServeHTTP(httptest.NewRecorder(), &http.Request{ 707 Method: "GET", 708 URL: &url.URL{Scheme: "http", Host: "fake.tld", Path: "/"}, 709 Proto: "HTTP/1.0", 710 ProtoMajor: 1, 711 }) 712 } 713 714 // Issue 14237. Test ModifyResponse and that an error from it 715 // causes the proxy to return StatusBadGateway, or StatusOK otherwise. 716 func TestReverseProxyModifyResponse(t *testing.T) { 717 backendServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 718 w.Header().Add("X-Hit-Mod", fmt.Sprintf("%v", r.URL.Path == "/mod")) 719 })) 720 defer backendServer.Close() 721 722 rpURL, _ := url.Parse(backendServer.URL) 723 rproxy := NewSingleHostReverseProxy(rpURL) 724 rproxy.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests 725 rproxy.ModifyResponse = func(resp *http.Response) error { 726 if resp.Header.Get("X-Hit-Mod") != "true" { 727 return fmt.Errorf("tried to by-pass proxy") 728 } 729 return nil 730 } 731 732 frontendProxy := httptest.NewServer(rproxy) 733 defer frontendProxy.Close() 734 735 tests := []struct { 736 url string 737 wantCode int 738 }{ 739 {frontendProxy.URL + "/mod", http.StatusOK}, 740 {frontendProxy.URL + "/schedule", http.StatusBadGateway}, 741 } 742 743 for i, tt := range tests { 744 resp, err := http.Get(tt.url) 745 if err != nil { 746 t.Fatalf("failed to reach proxy: %v", err) 747 } 748 if g, e := resp.StatusCode, tt.wantCode; g != e { 749 t.Errorf("#%d: got res.StatusCode %d; expected %d", i, g, e) 750 } 751 resp.Body.Close() 752 } 753 } 754 755 type failingRoundTripper struct{} 756 757 func (failingRoundTripper) RoundTrip(*http.Request) (*http.Response, error) { 758 return nil, errors.New("some error") 759 } 760 761 type staticResponseRoundTripper struct{ res *http.Response } 762 763 func (rt staticResponseRoundTripper) RoundTrip(*http.Request) (*http.Response, error) { 764 return rt.res, nil 765 } 766 767 func TestReverseProxyErrorHandler(t *testing.T) { 768 tests := []struct { 769 name string 770 wantCode int 771 errorHandler func(http.ResponseWriter, *http.Request, error) 772 transport http.RoundTripper // defaults to failingRoundTripper 773 modifyResponse func(*http.Response) error 774 }{ 775 { 776 name: "default", 777 wantCode: http.StatusBadGateway, 778 }, 779 { 780 name: "errorhandler", 781 wantCode: http.StatusTeapot, 782 errorHandler: func(rw http.ResponseWriter, req *http.Request, err error) { rw.WriteHeader(http.StatusTeapot) }, 783 }, 784 { 785 name: "modifyresponse_noerr", 786 transport: staticResponseRoundTripper{ 787 &http.Response{StatusCode: 345, Body: http.NoBody}, 788 }, 789 modifyResponse: func(res *http.Response) error { 790 res.StatusCode++ 791 return nil 792 }, 793 errorHandler: func(rw http.ResponseWriter, req *http.Request, err error) { rw.WriteHeader(http.StatusTeapot) }, 794 wantCode: 346, 795 }, 796 { 797 name: "modifyresponse_err", 798 transport: staticResponseRoundTripper{ 799 &http.Response{StatusCode: 345, Body: http.NoBody}, 800 }, 801 modifyResponse: func(res *http.Response) error { 802 res.StatusCode++ 803 return errors.New("some error to trigger errorHandler") 804 }, 805 errorHandler: func(rw http.ResponseWriter, req *http.Request, err error) { rw.WriteHeader(http.StatusTeapot) }, 806 wantCode: http.StatusTeapot, 807 }, 808 } 809 810 for _, tt := range tests { 811 t.Run(tt.name, func(t *testing.T) { 812 target := &url.URL{ 813 Scheme: "http", 814 Host: "dummy.tld", 815 Path: "/", 816 } 817 rproxy := NewSingleHostReverseProxy(target) 818 rproxy.Transport = tt.transport 819 rproxy.ModifyResponse = tt.modifyResponse 820 if rproxy.Transport == nil { 821 rproxy.Transport = failingRoundTripper{} 822 } 823 rproxy.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests 824 if tt.errorHandler != nil { 825 rproxy.ErrorHandler = tt.errorHandler 826 } 827 frontendProxy := httptest.NewServer(rproxy) 828 defer frontendProxy.Close() 829 830 resp, err := http.Get(frontendProxy.URL + "/test") 831 if err != nil { 832 t.Fatalf("failed to reach proxy: %v", err) 833 } 834 if g, e := resp.StatusCode, tt.wantCode; g != e { 835 t.Errorf("got res.StatusCode %d; expected %d", g, e) 836 } 837 resp.Body.Close() 838 }) 839 } 840 } 841 842 // Issue 16659: log errors from short read 843 func TestReverseProxy_CopyBuffer(t *testing.T) { 844 backendServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 845 out := "this call was relayed by the reverse proxy" 846 // Coerce a wrong content length to induce io.UnexpectedEOF 847 w.Header().Set("Content-Length", fmt.Sprintf("%d", len(out)*2)) 848 fmt.Fprintln(w, out) 849 })) 850 defer backendServer.Close() 851 852 rpURL, err := url.Parse(backendServer.URL) 853 if err != nil { 854 t.Fatal(err) 855 } 856 857 var proxyLog bytes.Buffer 858 rproxy := NewSingleHostReverseProxy(rpURL) 859 rproxy.ErrorLog = log.New(&proxyLog, "", log.Lshortfile) 860 donec := make(chan bool, 1) 861 frontendProxy := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 862 defer func() { donec <- true }() 863 rproxy.ServeHTTP(w, r) 864 })) 865 defer frontendProxy.Close() 866 867 if _, err = frontendProxy.Client().Get(frontendProxy.URL); err == nil { 868 t.Fatalf("want non-nil error") 869 } 870 // The race detector complains about the proxyLog usage in logf in copyBuffer 871 // and our usage below with proxyLog.Bytes() so we're explicitly using a 872 // channel to ensure that the ReverseProxy's ServeHTTP is done before we 873 // continue after Get. 874 <-donec 875 876 expected := []string{ 877 "EOF", 878 "read", 879 } 880 for _, phrase := range expected { 881 if !bytes.Contains(proxyLog.Bytes(), []byte(phrase)) { 882 t.Errorf("expected log to contain phrase %q", phrase) 883 } 884 } 885 } 886 887 type staticTransport struct { 888 res *http.Response 889 } 890 891 func (t *staticTransport) RoundTrip(r *http.Request) (*http.Response, error) { 892 return t.res, nil 893 } 894 895 func BenchmarkServeHTTP(b *testing.B) { 896 res := &http.Response{ 897 StatusCode: 200, 898 Body: io.NopCloser(strings.NewReader("")), 899 } 900 proxy := &ReverseProxy{ 901 Director: func(*http.Request) {}, 902 Transport: &staticTransport{res}, 903 } 904 905 w := httptest.NewRecorder() 906 r := httptest.NewRequest("GET", "/", nil) 907 908 b.ReportAllocs() 909 for i := 0; i < b.N; i++ { 910 proxy.ServeHTTP(w, r) 911 } 912 } 913 914 func TestServeHTTPDeepCopy(t *testing.T) { 915 backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 916 w.Write([]byte("Hello Gopher!")) 917 })) 918 defer backend.Close() 919 backendURL, err := url.Parse(backend.URL) 920 if err != nil { 921 t.Fatal(err) 922 } 923 924 type result struct { 925 before, after string 926 } 927 928 resultChan := make(chan result, 1) 929 proxyHandler := NewSingleHostReverseProxy(backendURL) 930 frontend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 931 before := r.URL.String() 932 proxyHandler.ServeHTTP(w, r) 933 after := r.URL.String() 934 resultChan <- result{before: before, after: after} 935 })) 936 defer frontend.Close() 937 938 want := result{before: "/", after: "/"} 939 940 res, err := frontend.Client().Get(frontend.URL) 941 if err != nil { 942 t.Fatalf("Do: %v", err) 943 } 944 res.Body.Close() 945 946 got := <-resultChan 947 if got != want { 948 t.Errorf("got = %+v; want = %+v", got, want) 949 } 950 } 951 952 // Issue 18327: verify we always do a deep copy of the Request.Header map 953 // before any mutations. 954 func TestClonesRequestHeaders(t *testing.T) { 955 log.SetOutput(io.Discard) 956 defer log.SetOutput(os.Stderr) 957 req, _ := http.NewRequest("GET", "http://foo.tld/", nil) 958 req.RemoteAddr = "1.2.3.4:56789" 959 rp := &ReverseProxy{ 960 Director: func(req *http.Request) { 961 req.Header.Set("From-Director", "1") 962 }, 963 Transport: roundTripperFunc(func(req *http.Request) (*http.Response, error) { 964 if v := req.Header.Get("From-Director"); v != "1" { 965 t.Errorf("From-Directory value = %q; want 1", v) 966 } 967 return nil, io.EOF 968 }), 969 } 970 rp.ServeHTTP(httptest.NewRecorder(), req) 971 972 if req.Header.Get("From-Director") == "1" { 973 t.Error("Director header mutation modified caller's request") 974 } 975 if req.Header.Get("X-Forwarded-For") != "" { 976 t.Error("X-Forward-For header mutation modified caller's request") 977 } 978 979 } 980 981 type roundTripperFunc func(req *http.Request) (*http.Response, error) 982 983 func (fn roundTripperFunc) RoundTrip(req *http.Request) (*http.Response, error) { 984 return fn(req) 985 } 986 987 func TestModifyResponseClosesBody(t *testing.T) { 988 req, _ := http.NewRequest("GET", "http://foo.tld/", nil) 989 req.RemoteAddr = "1.2.3.4:56789" 990 closeCheck := new(checkCloser) 991 logBuf := new(bytes.Buffer) 992 outErr := errors.New("ModifyResponse error") 993 rp := &ReverseProxy{ 994 Director: func(req *http.Request) {}, 995 Transport: &staticTransport{&http.Response{ 996 StatusCode: 200, 997 Body: closeCheck, 998 }}, 999 ErrorLog: log.New(logBuf, "", 0), 1000 ModifyResponse: func(*http.Response) error { 1001 return outErr 1002 }, 1003 } 1004 rec := httptest.NewRecorder() 1005 rp.ServeHTTP(rec, req) 1006 res := rec.Result() 1007 if g, e := res.StatusCode, http.StatusBadGateway; g != e { 1008 t.Errorf("got res.StatusCode %d; expected %d", g, e) 1009 } 1010 if !closeCheck.closed { 1011 t.Errorf("body should have been closed") 1012 } 1013 if g, e := logBuf.String(), outErr.Error(); !strings.Contains(g, e) { 1014 t.Errorf("ErrorLog %q does not contain %q", g, e) 1015 } 1016 } 1017 1018 type checkCloser struct { 1019 closed bool 1020 } 1021 1022 func (cc *checkCloser) Close() error { 1023 cc.closed = true 1024 return nil 1025 } 1026 1027 func (cc *checkCloser) Read(b []byte) (int, error) { 1028 return len(b), nil 1029 } 1030 1031 // Issue 23643: panic on body copy error 1032 func TestReverseProxy_PanicBodyError(t *testing.T) { 1033 log.SetOutput(io.Discard) 1034 defer log.SetOutput(os.Stderr) 1035 backendServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 1036 out := "this call was relayed by the reverse proxy" 1037 // Coerce a wrong content length to induce io.ErrUnexpectedEOF 1038 w.Header().Set("Content-Length", fmt.Sprintf("%d", len(out)*2)) 1039 fmt.Fprintln(w, out) 1040 })) 1041 defer backendServer.Close() 1042 1043 rpURL, err := url.Parse(backendServer.URL) 1044 if err != nil { 1045 t.Fatal(err) 1046 } 1047 1048 rproxy := NewSingleHostReverseProxy(rpURL) 1049 1050 // Ensure that the handler panics when the body read encounters an 1051 // io.ErrUnexpectedEOF 1052 defer func() { 1053 err := recover() 1054 if err == nil { 1055 t.Fatal("handler should have panicked") 1056 } 1057 if err != http.ErrAbortHandler { 1058 t.Fatal("expected ErrAbortHandler, got", err) 1059 } 1060 }() 1061 req, _ := http.NewRequest("GET", "http://foo.tld/", nil) 1062 rproxy.ServeHTTP(httptest.NewRecorder(), req) 1063 } 1064 1065 func TestSelectFlushInterval(t *testing.T) { 1066 tests := []struct { 1067 name string 1068 p *ReverseProxy 1069 res *http.Response 1070 want time.Duration 1071 }{ 1072 { 1073 name: "default", 1074 res: &http.Response{}, 1075 p: &ReverseProxy{FlushInterval: 123}, 1076 want: 123, 1077 }, 1078 { 1079 name: "server-sent events overrides non-zero", 1080 res: &http.Response{ 1081 Header: http.Header{ 1082 "Content-Type": {"text/event-stream"}, 1083 }, 1084 }, 1085 p: &ReverseProxy{FlushInterval: 123}, 1086 want: -1, 1087 }, 1088 { 1089 name: "server-sent events overrides zero", 1090 res: &http.Response{ 1091 Header: http.Header{ 1092 "Content-Type": {"text/event-stream"}, 1093 }, 1094 }, 1095 p: &ReverseProxy{FlushInterval: 0}, 1096 want: -1, 1097 }, 1098 { 1099 name: "Content-Length: -1, overrides non-zero", 1100 res: &http.Response{ 1101 ContentLength: -1, 1102 }, 1103 p: &ReverseProxy{FlushInterval: 123}, 1104 want: -1, 1105 }, 1106 { 1107 name: "Content-Length: -1, overrides zero", 1108 res: &http.Response{ 1109 ContentLength: -1, 1110 }, 1111 p: &ReverseProxy{FlushInterval: 0}, 1112 want: -1, 1113 }, 1114 } 1115 for _, tt := range tests { 1116 t.Run(tt.name, func(t *testing.T) { 1117 got := tt.p.flushInterval(tt.res) 1118 if got != tt.want { 1119 t.Errorf("flushLatency = %v; want %v", got, tt.want) 1120 } 1121 }) 1122 } 1123 } 1124 1125 func TestReverseProxyWebSocket(t *testing.T) { 1126 backendServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 1127 if upgradeType(r.Header) != "websocket" { 1128 t.Error("unexpected backend request") 1129 http.Error(w, "unexpected request", 400) 1130 return 1131 } 1132 c, _, err := w.(http.Hijacker).Hijack() 1133 if err != nil { 1134 t.Error(err) 1135 return 1136 } 1137 defer c.Close() 1138 io.WriteString(c, "HTTP/1.1 101 Switching Protocols\r\nConnection: upgrade\r\nUpgrade: WebSocket\r\n\r\n") 1139 bs := bufio.NewScanner(c) 1140 if !bs.Scan() { 1141 t.Errorf("backend failed to read line from client: %v", bs.Err()) 1142 return 1143 } 1144 fmt.Fprintf(c, "backend got %q\n", bs.Text()) 1145 })) 1146 defer backendServer.Close() 1147 1148 backURL, _ := url.Parse(backendServer.URL) 1149 rproxy := NewSingleHostReverseProxy(backURL) 1150 rproxy.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests 1151 rproxy.ModifyResponse = func(res *http.Response) error { 1152 res.Header.Add("X-Modified", "true") 1153 return nil 1154 } 1155 1156 handler := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { 1157 rw.Header().Set("X-Header", "X-Value") 1158 rproxy.ServeHTTP(rw, req) 1159 if got, want := rw.Header().Get("X-Modified"), "true"; got != want { 1160 t.Errorf("response writer X-Modified header = %q; want %q", got, want) 1161 } 1162 }) 1163 1164 frontendProxy := httptest.NewServer(handler) 1165 defer frontendProxy.Close() 1166 1167 req, _ := http.NewRequest("GET", frontendProxy.URL, nil) 1168 req.Header.Set("Connection", "Upgrade") 1169 req.Header.Set("Upgrade", "websocket") 1170 1171 c := frontendProxy.Client() 1172 res, err := c.Do(req) 1173 if err != nil { 1174 t.Fatal(err) 1175 } 1176 if res.StatusCode != 101 { 1177 t.Fatalf("status = %v; want 101", res.Status) 1178 } 1179 1180 got := res.Header.Get("X-Header") 1181 want := "X-Value" 1182 if got != want { 1183 t.Errorf("Header(XHeader) = %q; want %q", got, want) 1184 } 1185 1186 if upgradeType(res.Header) != "websocket" { 1187 t.Fatalf("not websocket upgrade; got %#v", res.Header) 1188 } 1189 rwc, ok := res.Body.(io.ReadWriteCloser) 1190 if !ok { 1191 t.Fatalf("response body is of type %T; does not implement ReadWriteCloser", res.Body) 1192 } 1193 defer rwc.Close() 1194 1195 if got, want := res.Header.Get("X-Modified"), "true"; got != want { 1196 t.Errorf("response X-Modified header = %q; want %q", got, want) 1197 } 1198 1199 io.WriteString(rwc, "Hello\n") 1200 bs := bufio.NewScanner(rwc) 1201 if !bs.Scan() { 1202 t.Fatalf("Scan: %v", bs.Err()) 1203 } 1204 got = bs.Text() 1205 want = `backend got "Hello"` 1206 if got != want { 1207 t.Errorf("got %#q, want %#q", got, want) 1208 } 1209 } 1210 1211 func TestReverseProxyWebSocketCancelation(t *testing.T) { 1212 n := 5 1213 triggerCancelCh := make(chan bool, n) 1214 nthResponse := func(i int) string { 1215 return fmt.Sprintf("backend response #%d\n", i) 1216 } 1217 terminalMsg := "final message" 1218 1219 cst := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 1220 if g, ws := upgradeType(r.Header), "websocket"; g != ws { 1221 t.Errorf("Unexpected upgrade type %q, want %q", g, ws) 1222 http.Error(w, "Unexpected request", 400) 1223 return 1224 } 1225 conn, bufrw, err := w.(http.Hijacker).Hijack() 1226 if err != nil { 1227 t.Error(err) 1228 return 1229 } 1230 defer conn.Close() 1231 1232 upgradeMsg := "HTTP/1.1 101 Switching Protocols\r\nConnection: upgrade\r\nUpgrade: WebSocket\r\n\r\n" 1233 if _, err := io.WriteString(conn, upgradeMsg); err != nil { 1234 t.Error(err) 1235 return 1236 } 1237 if _, _, err := bufrw.ReadLine(); err != nil { 1238 t.Errorf("Failed to read line from client: %v", err) 1239 return 1240 } 1241 1242 for i := 0; i < n; i++ { 1243 if _, err := bufrw.WriteString(nthResponse(i)); err != nil { 1244 select { 1245 case <-triggerCancelCh: 1246 default: 1247 t.Errorf("Writing response #%d failed: %v", i, err) 1248 } 1249 return 1250 } 1251 bufrw.Flush() 1252 time.Sleep(time.Second) 1253 } 1254 if _, err := bufrw.WriteString(terminalMsg); err != nil { 1255 select { 1256 case <-triggerCancelCh: 1257 default: 1258 t.Errorf("Failed to write terminal message: %v", err) 1259 } 1260 } 1261 bufrw.Flush() 1262 })) 1263 defer cst.Close() 1264 1265 backendURL, _ := url.Parse(cst.URL) 1266 rproxy := NewSingleHostReverseProxy(backendURL) 1267 rproxy.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests 1268 rproxy.ModifyResponse = func(res *http.Response) error { 1269 res.Header.Add("X-Modified", "true") 1270 return nil 1271 } 1272 1273 handler := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { 1274 rw.Header().Set("X-Header", "X-Value") 1275 ctx, cancel := context.WithCancel(req.Context()) 1276 go func() { 1277 <-triggerCancelCh 1278 cancel() 1279 }() 1280 rproxy.ServeHTTP(rw, req.WithContext(ctx)) 1281 }) 1282 1283 frontendProxy := httptest.NewServer(handler) 1284 defer frontendProxy.Close() 1285 1286 req, _ := http.NewRequest("GET", frontendProxy.URL, nil) 1287 req.Header.Set("Connection", "Upgrade") 1288 req.Header.Set("Upgrade", "websocket") 1289 1290 res, err := frontendProxy.Client().Do(req) 1291 if err != nil { 1292 t.Fatalf("Dialing to frontend proxy: %v", err) 1293 } 1294 defer res.Body.Close() 1295 if g, w := res.StatusCode, 101; g != w { 1296 t.Fatalf("Switching protocols failed, got: %d, want: %d", g, w) 1297 } 1298 1299 if g, w := res.Header.Get("X-Header"), "X-Value"; g != w { 1300 t.Errorf("X-Header mismatch\n\tgot: %q\n\twant: %q", g, w) 1301 } 1302 1303 if g, w := upgradeType(res.Header), "websocket"; g != w { 1304 t.Fatalf("Upgrade header mismatch\n\tgot: %q\n\twant: %q", g, w) 1305 } 1306 1307 rwc, ok := res.Body.(io.ReadWriteCloser) 1308 if !ok { 1309 t.Fatalf("Response body type mismatch, got %T, want io.ReadWriteCloser", res.Body) 1310 } 1311 1312 if got, want := res.Header.Get("X-Modified"), "true"; got != want { 1313 t.Errorf("response X-Modified header = %q; want %q", got, want) 1314 } 1315 1316 if _, err := io.WriteString(rwc, "Hello\n"); err != nil { 1317 t.Fatalf("Failed to write first message: %v", err) 1318 } 1319 1320 // Read loop. 1321 1322 br := bufio.NewReader(rwc) 1323 for { 1324 line, err := br.ReadString('\n') 1325 switch { 1326 case line == terminalMsg: // this case before "err == io.EOF" 1327 t.Fatalf("The websocket request was not canceled, unfortunately!") 1328 1329 case err == io.EOF: 1330 return 1331 1332 case err != nil: 1333 t.Fatalf("Unexpected error: %v", err) 1334 1335 case line == nthResponse(0): // We've gotten the first response back 1336 // Let's trigger a cancel. 1337 close(triggerCancelCh) 1338 } 1339 } 1340 } 1341 1342 func TestUnannouncedTrailer(t *testing.T) { 1343 backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 1344 w.WriteHeader(http.StatusOK) 1345 w.(http.Flusher).Flush() 1346 w.Header().Set(http.TrailerPrefix+"X-Unannounced-Trailer", "unannounced_trailer_value") 1347 })) 1348 defer backend.Close() 1349 backendURL, err := url.Parse(backend.URL) 1350 if err != nil { 1351 t.Fatal(err) 1352 } 1353 proxyHandler := NewSingleHostReverseProxy(backendURL) 1354 proxyHandler.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests 1355 frontend := httptest.NewServer(proxyHandler) 1356 defer frontend.Close() 1357 frontendClient := frontend.Client() 1358 1359 res, err := frontendClient.Get(frontend.URL) 1360 if err != nil { 1361 t.Fatalf("Get: %v", err) 1362 } 1363 1364 io.ReadAll(res.Body) 1365 1366 if g, w := res.Trailer.Get("X-Unannounced-Trailer"), "unannounced_trailer_value"; g != w { 1367 t.Errorf("Trailer(X-Unannounced-Trailer) = %q; want %q", g, w) 1368 } 1369 1370 } 1371 1372 func TestSingleJoinSlash(t *testing.T) { 1373 tests := []struct { 1374 slasha string 1375 slashb string 1376 expected string 1377 }{ 1378 {"https://www.google.com/", "/favicon.ico", "https://www.google.com/favicon.ico"}, 1379 {"https://www.google.com", "/favicon.ico", "https://www.google.com/favicon.ico"}, 1380 {"https://www.google.com", "favicon.ico", "https://www.google.com/favicon.ico"}, 1381 {"https://www.google.com", "", "https://www.google.com/"}, 1382 {"", "favicon.ico", "/favicon.ico"}, 1383 } 1384 for _, tt := range tests { 1385 if got := singleJoiningSlash(tt.slasha, tt.slashb); got != tt.expected { 1386 t.Errorf("singleJoiningSlash(%q,%q) want %q got %q", 1387 tt.slasha, 1388 tt.slashb, 1389 tt.expected, 1390 got) 1391 } 1392 } 1393 } 1394 1395 func TestJoinURLPath(t *testing.T) { 1396 tests := []struct { 1397 a *url.URL 1398 b *url.URL 1399 wantPath string 1400 wantRaw string 1401 }{ 1402 {&url.URL{Path: "/a/b"}, &url.URL{Path: "/c"}, "/a/b/c", ""}, 1403 {&url.URL{Path: "/a/b", RawPath: "badpath"}, &url.URL{Path: "c"}, "/a/b/c", "/a/b/c"}, 1404 {&url.URL{Path: "/a/b", RawPath: "/a%2Fb"}, &url.URL{Path: "/c"}, "/a/b/c", "/a%2Fb/c"}, 1405 {&url.URL{Path: "/a/b", RawPath: "/a%2Fb"}, &url.URL{Path: "/c"}, "/a/b/c", "/a%2Fb/c"}, 1406 {&url.URL{Path: "/a/b/", RawPath: "/a%2Fb%2F"}, &url.URL{Path: "c"}, "/a/b//c", "/a%2Fb%2F/c"}, 1407 {&url.URL{Path: "/a/b/", RawPath: "/a%2Fb/"}, &url.URL{Path: "/c/d", RawPath: "/c%2Fd"}, "/a/b/c/d", "/a%2Fb/c%2Fd"}, 1408 } 1409 1410 for _, tt := range tests { 1411 p, rp := joinURLPath(tt.a, tt.b) 1412 if p != tt.wantPath || rp != tt.wantRaw { 1413 t.Errorf("joinURLPath(URL(%q,%q),URL(%q,%q)) want (%q,%q) got (%q,%q)", 1414 tt.a.Path, tt.a.RawPath, 1415 tt.b.Path, tt.b.RawPath, 1416 tt.wantPath, tt.wantRaw, 1417 p, rp) 1418 } 1419 } 1420 }