github.com/ice-blockchain/go/src@v0.0.0-20240403114104-1564d284e521/net/http/transport_test.go (about) 1 // Copyright 2011 The Go Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 // Tests for transport.go. 6 // 7 // More tests are in clientserver_test.go (for things testing both client & server for both 8 // HTTP/1 and HTTP/2). This 9 10 package http_test 11 12 import ( 13 "bufio" 14 "bytes" 15 "compress/gzip" 16 "context" 17 "crypto/rand" 18 "crypto/tls" 19 "crypto/x509" 20 "encoding/binary" 21 "errors" 22 "fmt" 23 "go/token" 24 "internal/nettrace" 25 "io" 26 "log" 27 mrand "math/rand" 28 "net" 29 . "net/http" 30 "net/http/httptest" 31 "net/http/httptrace" 32 "net/http/httputil" 33 "net/http/internal/testcert" 34 "net/textproto" 35 "net/url" 36 "os" 37 "reflect" 38 "runtime" 39 "strconv" 40 "strings" 41 "sync" 42 "sync/atomic" 43 "testing" 44 "testing/iotest" 45 "time" 46 47 "golang.org/x/net/http/httpguts" 48 ) 49 50 // TODO: test 5 pipelined requests with responses: 1) OK, 2) OK, Connection: Close 51 // and then verify that the final 2 responses get errors back. 52 53 // hostPortHandler writes back the client's "host:port". 54 var hostPortHandler = HandlerFunc(func(w ResponseWriter, r *Request) { 55 if r.FormValue("close") == "true" { 56 w.Header().Set("Connection", "close") 57 } 58 w.Header().Set("X-Saw-Close", fmt.Sprint(r.Close)) 59 w.Write([]byte(r.RemoteAddr)) 60 61 // Include the address of the net.Conn in addition to the RemoteAddr, 62 // in case kernels reuse source ports quickly (see Issue 52450) 63 if c, ok := ResponseWriterConnForTesting(w); ok { 64 fmt.Fprintf(w, ", %T %p", c, c) 65 } 66 }) 67 68 // testCloseConn is a net.Conn tracked by a testConnSet. 
69 type testCloseConn struct { 70 net.Conn 71 set *testConnSet 72 } 73 74 func (c *testCloseConn) Close() error { 75 c.set.remove(c) 76 return c.Conn.Close() 77 } 78 79 // testConnSet tracks a set of TCP connections and whether they've 80 // been closed. 81 type testConnSet struct { 82 t *testing.T 83 mu sync.Mutex // guards closed and list 84 closed map[net.Conn]bool 85 list []net.Conn // in order created 86 } 87 88 func (tcs *testConnSet) insert(c net.Conn) { 89 tcs.mu.Lock() 90 defer tcs.mu.Unlock() 91 tcs.closed[c] = false 92 tcs.list = append(tcs.list, c) 93 } 94 95 func (tcs *testConnSet) remove(c net.Conn) { 96 tcs.mu.Lock() 97 defer tcs.mu.Unlock() 98 tcs.closed[c] = true 99 } 100 101 // some tests use this to manage raw tcp connections for later inspection 102 func makeTestDial(t *testing.T) (*testConnSet, func(n, addr string) (net.Conn, error)) { 103 connSet := &testConnSet{ 104 t: t, 105 closed: make(map[net.Conn]bool), 106 } 107 dial := func(n, addr string) (net.Conn, error) { 108 c, err := net.Dial(n, addr) 109 if err != nil { 110 return nil, err 111 } 112 tc := &testCloseConn{c, connSet} 113 connSet.insert(tc) 114 return tc, nil 115 } 116 return connSet, dial 117 } 118 119 func (tcs *testConnSet) check(t *testing.T) { 120 tcs.mu.Lock() 121 defer tcs.mu.Unlock() 122 for i := 4; i >= 0; i-- { 123 for i, c := range tcs.list { 124 if tcs.closed[c] { 125 continue 126 } 127 if i != 0 { 128 // TODO(bcmills): What is the Sleep here doing, and why is this 129 // Unlock/Sleep/Lock cycle needed at all? 
130 tcs.mu.Unlock() 131 time.Sleep(50 * time.Millisecond) 132 tcs.mu.Lock() 133 continue 134 } 135 t.Errorf("TCP connection #%d, %p (of %d total) was not closed", i+1, c, len(tcs.list)) 136 } 137 } 138 } 139 140 func TestReuseRequest(t *testing.T) { run(t, testReuseRequest) } 141 func testReuseRequest(t *testing.T, mode testMode) { 142 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { 143 w.Write([]byte("{}")) 144 })).ts 145 146 c := ts.Client() 147 req, _ := NewRequest("GET", ts.URL, nil) 148 res, err := c.Do(req) 149 if err != nil { 150 t.Fatal(err) 151 } 152 err = res.Body.Close() 153 if err != nil { 154 t.Fatal(err) 155 } 156 157 res, err = c.Do(req) 158 if err != nil { 159 t.Fatal(err) 160 } 161 err = res.Body.Close() 162 if err != nil { 163 t.Fatal(err) 164 } 165 } 166 167 // Two subsequent requests and verify their response is the same. 168 // The response from the server is our own IP:port 169 func TestTransportKeepAlives(t *testing.T) { run(t, testTransportKeepAlives, []testMode{http1Mode}) } 170 func testTransportKeepAlives(t *testing.T, mode testMode) { 171 ts := newClientServerTest(t, mode, hostPortHandler).ts 172 173 c := ts.Client() 174 for _, disableKeepAlive := range []bool{false, true} { 175 c.Transport.(*Transport).DisableKeepAlives = disableKeepAlive 176 fetch := func(n int) string { 177 res, err := c.Get(ts.URL) 178 if err != nil { 179 t.Fatalf("error in disableKeepAlive=%v, req #%d, GET: %v", disableKeepAlive, n, err) 180 } 181 body, err := io.ReadAll(res.Body) 182 if err != nil { 183 t.Fatalf("error in disableKeepAlive=%v, req #%d, ReadAll: %v", disableKeepAlive, n, err) 184 } 185 return string(body) 186 } 187 188 body1 := fetch(1) 189 body2 := fetch(2) 190 191 bodiesDiffer := body1 != body2 192 if bodiesDiffer != disableKeepAlive { 193 t.Errorf("error in disableKeepAlive=%v. 
unexpected bodiesDiffer=%v; body1=%q; body2=%q", 194 disableKeepAlive, bodiesDiffer, body1, body2) 195 } 196 } 197 } 198 199 func TestTransportConnectionCloseOnResponse(t *testing.T) { 200 run(t, testTransportConnectionCloseOnResponse) 201 } 202 func testTransportConnectionCloseOnResponse(t *testing.T, mode testMode) { 203 ts := newClientServerTest(t, mode, hostPortHandler).ts 204 205 connSet, testDial := makeTestDial(t) 206 207 c := ts.Client() 208 tr := c.Transport.(*Transport) 209 tr.Dial = testDial 210 211 for _, connectionClose := range []bool{false, true} { 212 fetch := func(n int) string { 213 req := new(Request) 214 var err error 215 req.URL, err = url.Parse(ts.URL + fmt.Sprintf("/?close=%v", connectionClose)) 216 if err != nil { 217 t.Fatalf("URL parse error: %v", err) 218 } 219 req.Method = "GET" 220 req.Proto = "HTTP/1.1" 221 req.ProtoMajor = 1 222 req.ProtoMinor = 1 223 224 res, err := c.Do(req) 225 if err != nil { 226 t.Fatalf("error in connectionClose=%v, req #%d, Do: %v", connectionClose, n, err) 227 } 228 defer res.Body.Close() 229 body, err := io.ReadAll(res.Body) 230 if err != nil { 231 t.Fatalf("error in connectionClose=%v, req #%d, ReadAll: %v", connectionClose, n, err) 232 } 233 return string(body) 234 } 235 236 body1 := fetch(1) 237 body2 := fetch(2) 238 bodiesDiffer := body1 != body2 239 if bodiesDiffer != connectionClose { 240 t.Errorf("error in connectionClose=%v. unexpected bodiesDiffer=%v; body1=%q; body2=%q", 241 connectionClose, bodiesDiffer, body1, body2) 242 } 243 244 tr.CloseIdleConnections() 245 } 246 247 connSet.check(t) 248 } 249 250 // TestTransportConnectionCloseOnRequest tests that the Transport's doesn't reuse 251 // an underlying TCP connection after making an http.Request with Request.Close set. 252 // 253 // It tests the behavior by making an HTTP request to a server which 254 // describes the source connection it got (remote port number + 255 // address of its net.Conn). 
256 func TestTransportConnectionCloseOnRequest(t *testing.T) { 257 run(t, testTransportConnectionCloseOnRequest, []testMode{http1Mode}) 258 } 259 func testTransportConnectionCloseOnRequest(t *testing.T, mode testMode) { 260 ts := newClientServerTest(t, mode, hostPortHandler).ts 261 262 connSet, testDial := makeTestDial(t) 263 264 c := ts.Client() 265 tr := c.Transport.(*Transport) 266 tr.Dial = testDial 267 for _, reqClose := range []bool{false, true} { 268 fetch := func(n int) string { 269 req := new(Request) 270 var err error 271 req.URL, err = url.Parse(ts.URL) 272 if err != nil { 273 t.Fatalf("URL parse error: %v", err) 274 } 275 req.Method = "GET" 276 req.Proto = "HTTP/1.1" 277 req.ProtoMajor = 1 278 req.ProtoMinor = 1 279 req.Close = reqClose 280 281 res, err := c.Do(req) 282 if err != nil { 283 t.Fatalf("error in Request.Close=%v, req #%d, Do: %v", reqClose, n, err) 284 } 285 if got, want := res.Header.Get("X-Saw-Close"), fmt.Sprint(reqClose); got != want { 286 t.Errorf("for Request.Close = %v; handler's X-Saw-Close was %v; want %v", 287 reqClose, got, !reqClose) 288 } 289 body, err := io.ReadAll(res.Body) 290 if err != nil { 291 t.Fatalf("for Request.Close=%v, on request %v/2: ReadAll: %v", reqClose, n, err) 292 } 293 return string(body) 294 } 295 296 body1 := fetch(1) 297 body2 := fetch(2) 298 299 got := 1 300 if body1 != body2 { 301 got++ 302 } 303 want := 1 304 if reqClose { 305 want = 2 306 } 307 if got != want { 308 t.Errorf("for Request.Close=%v: server saw %v unique connections, wanted %v\n\nbodies were: %q and %q", 309 reqClose, got, want, body1, body2) 310 } 311 312 tr.CloseIdleConnections() 313 } 314 315 connSet.check(t) 316 } 317 318 // if the Transport's DisableKeepAlives is set, all requests should 319 // send Connection: close. 
320 // HTTP/1-only (Connection: close doesn't exist in h2) 321 func TestTransportConnectionCloseOnRequestDisableKeepAlive(t *testing.T) { 322 run(t, testTransportConnectionCloseOnRequestDisableKeepAlive, []testMode{http1Mode}) 323 } 324 func testTransportConnectionCloseOnRequestDisableKeepAlive(t *testing.T, mode testMode) { 325 ts := newClientServerTest(t, mode, hostPortHandler).ts 326 327 c := ts.Client() 328 c.Transport.(*Transport).DisableKeepAlives = true 329 330 res, err := c.Get(ts.URL) 331 if err != nil { 332 t.Fatal(err) 333 } 334 res.Body.Close() 335 if res.Header.Get("X-Saw-Close") != "true" { 336 t.Errorf("handler didn't see Connection: close ") 337 } 338 } 339 340 // Test that Transport only sends one "Connection: close", regardless of 341 // how "close" was indicated. 342 func TestTransportRespectRequestWantsClose(t *testing.T) { 343 run(t, testTransportRespectRequestWantsClose, []testMode{http1Mode}) 344 } 345 func testTransportRespectRequestWantsClose(t *testing.T, mode testMode) { 346 tests := []struct { 347 disableKeepAlives bool 348 close bool 349 }{ 350 {disableKeepAlives: false, close: false}, 351 {disableKeepAlives: false, close: true}, 352 {disableKeepAlives: true, close: false}, 353 {disableKeepAlives: true, close: true}, 354 } 355 356 for _, tc := range tests { 357 t.Run(fmt.Sprintf("DisableKeepAlive=%v,RequestClose=%v", tc.disableKeepAlives, tc.close), 358 func(t *testing.T) { 359 ts := newClientServerTest(t, mode, hostPortHandler).ts 360 361 c := ts.Client() 362 c.Transport.(*Transport).DisableKeepAlives = tc.disableKeepAlives 363 req, err := NewRequest("GET", ts.URL, nil) 364 if err != nil { 365 t.Fatal(err) 366 } 367 count := 0 368 trace := &httptrace.ClientTrace{ 369 WroteHeaderField: func(key string, field []string) { 370 if key != "Connection" { 371 return 372 } 373 if httpguts.HeaderValuesContainsToken(field, "close") { 374 count += 1 375 } 376 }, 377 } 378 req = req.WithContext(httptrace.WithClientTrace(req.Context(), trace)) 379 
req.Close = tc.close 380 res, err := c.Do(req) 381 if err != nil { 382 t.Fatal(err) 383 } 384 defer res.Body.Close() 385 if want := tc.disableKeepAlives || tc.close; count > 1 || (count == 1) != want { 386 t.Errorf("expecting want:%v, got 'Connection: close':%d", want, count) 387 } 388 }) 389 } 390 391 } 392 393 func TestTransportIdleCacheKeys(t *testing.T) { 394 run(t, testTransportIdleCacheKeys, []testMode{http1Mode}) 395 } 396 func testTransportIdleCacheKeys(t *testing.T, mode testMode) { 397 ts := newClientServerTest(t, mode, hostPortHandler).ts 398 c := ts.Client() 399 tr := c.Transport.(*Transport) 400 401 if e, g := 0, len(tr.IdleConnKeysForTesting()); e != g { 402 t.Errorf("After CloseIdleConnections expected %d idle conn cache keys; got %d", e, g) 403 } 404 405 resp, err := c.Get(ts.URL) 406 if err != nil { 407 t.Error(err) 408 } 409 io.ReadAll(resp.Body) 410 411 keys := tr.IdleConnKeysForTesting() 412 if e, g := 1, len(keys); e != g { 413 t.Fatalf("After Get expected %d idle conn cache keys; got %d", e, g) 414 } 415 416 if e := "|http|" + ts.Listener.Addr().String(); keys[0] != e { 417 t.Errorf("Expected idle cache key %q; got %q", e, keys[0]) 418 } 419 420 tr.CloseIdleConnections() 421 if e, g := 0, len(tr.IdleConnKeysForTesting()); e != g { 422 t.Errorf("After CloseIdleConnections expected %d idle conn cache keys; got %d", e, g) 423 } 424 } 425 426 // Tests that the HTTP transport re-uses connections when a client 427 // reads to the end of a response Body without closing it. 
428 func TestTransportReadToEndReusesConn(t *testing.T) { run(t, testTransportReadToEndReusesConn) } 429 func testTransportReadToEndReusesConn(t *testing.T, mode testMode) { 430 const msg = "foobar" 431 432 var addrSeen map[string]int 433 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { 434 addrSeen[r.RemoteAddr]++ 435 if r.URL.Path == "/chunked/" { 436 w.WriteHeader(200) 437 w.(Flusher).Flush() 438 } else { 439 w.Header().Set("Content-Length", strconv.Itoa(len(msg))) 440 w.WriteHeader(200) 441 } 442 w.Write([]byte(msg)) 443 })).ts 444 445 for pi, path := range []string{"/content-length/", "/chunked/"} { 446 wantLen := []int{len(msg), -1}[pi] 447 addrSeen = make(map[string]int) 448 for i := 0; i < 3; i++ { 449 res, err := ts.Client().Get(ts.URL + path) 450 if err != nil { 451 t.Errorf("Get %s: %v", path, err) 452 continue 453 } 454 // We want to close this body eventually (before the 455 // defer afterTest at top runs), but not before the 456 // len(addrSeen) check at the bottom of this test, 457 // since Closing this early in the loop would risk 458 // making connections be re-used for the wrong reason. 
459 defer res.Body.Close() 460 461 if res.ContentLength != int64(wantLen) { 462 t.Errorf("%s res.ContentLength = %d; want %d", path, res.ContentLength, wantLen) 463 } 464 got, err := io.ReadAll(res.Body) 465 if string(got) != msg || err != nil { 466 t.Errorf("%s ReadAll(Body) = %q, %v; want %q, nil", path, string(got), err, msg) 467 } 468 } 469 if len(addrSeen) != 1 { 470 t.Errorf("for %s, server saw %d distinct client addresses; want 1", path, len(addrSeen)) 471 } 472 } 473 } 474 475 func TestTransportMaxPerHostIdleConns(t *testing.T) { 476 run(t, testTransportMaxPerHostIdleConns, []testMode{http1Mode}) 477 } 478 func testTransportMaxPerHostIdleConns(t *testing.T, mode testMode) { 479 stop := make(chan struct{}) // stop marks the exit of main Test goroutine 480 defer close(stop) 481 482 resch := make(chan string) 483 gotReq := make(chan bool) 484 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { 485 gotReq <- true 486 var msg string 487 select { 488 case <-stop: 489 return 490 case msg = <-resch: 491 } 492 _, err := w.Write([]byte(msg)) 493 if err != nil { 494 t.Errorf("Write: %v", err) 495 return 496 } 497 })).ts 498 499 c := ts.Client() 500 tr := c.Transport.(*Transport) 501 maxIdleConnsPerHost := 2 502 tr.MaxIdleConnsPerHost = maxIdleConnsPerHost 503 504 // Start 3 outstanding requests and wait for the server to get them. 505 // Their responses will hang until we write to resch, though. 
506 donech := make(chan bool) 507 doReq := func() { 508 defer func() { 509 select { 510 case <-stop: 511 return 512 case donech <- t.Failed(): 513 } 514 }() 515 resp, err := c.Get(ts.URL) 516 if err != nil { 517 t.Error(err) 518 return 519 } 520 if _, err := io.ReadAll(resp.Body); err != nil { 521 t.Errorf("ReadAll: %v", err) 522 return 523 } 524 } 525 go doReq() 526 <-gotReq 527 go doReq() 528 <-gotReq 529 go doReq() 530 <-gotReq 531 532 if e, g := 0, len(tr.IdleConnKeysForTesting()); e != g { 533 t.Fatalf("Before writes, expected %d idle conn cache keys; got %d", e, g) 534 } 535 536 resch <- "res1" 537 <-donech 538 keys := tr.IdleConnKeysForTesting() 539 if e, g := 1, len(keys); e != g { 540 t.Fatalf("after first response, expected %d idle conn cache keys; got %d", e, g) 541 } 542 addr := ts.Listener.Addr().String() 543 cacheKey := "|http|" + addr 544 if keys[0] != cacheKey { 545 t.Fatalf("Expected idle cache key %q; got %q", cacheKey, keys[0]) 546 } 547 if e, g := 1, tr.IdleConnCountForTesting("http", addr); e != g { 548 t.Errorf("after first response, expected %d idle conns; got %d", e, g) 549 } 550 551 resch <- "res2" 552 <-donech 553 if g, w := tr.IdleConnCountForTesting("http", addr), 2; g != w { 554 t.Errorf("after second response, idle conns = %d; want %d", g, w) 555 } 556 557 resch <- "res3" 558 <-donech 559 if g, w := tr.IdleConnCountForTesting("http", addr), maxIdleConnsPerHost; g != w { 560 t.Errorf("after third response, idle conns = %d; want %d", g, w) 561 } 562 } 563 564 func TestTransportMaxConnsPerHostIncludeDialInProgress(t *testing.T) { 565 run(t, testTransportMaxConnsPerHostIncludeDialInProgress) 566 } 567 func testTransportMaxConnsPerHostIncludeDialInProgress(t *testing.T, mode testMode) { 568 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { 569 _, err := w.Write([]byte("foo")) 570 if err != nil { 571 t.Fatalf("Write: %v", err) 572 } 573 })).ts 574 c := ts.Client() 575 tr := c.Transport.(*Transport) 576 
dialStarted := make(chan struct{}) 577 stallDial := make(chan struct{}) 578 tr.Dial = func(network, addr string) (net.Conn, error) { 579 dialStarted <- struct{}{} 580 <-stallDial 581 return net.Dial(network, addr) 582 } 583 584 tr.DisableKeepAlives = true 585 tr.MaxConnsPerHost = 1 586 587 preDial := make(chan struct{}) 588 reqComplete := make(chan struct{}) 589 doReq := func(reqId string) { 590 req, _ := NewRequest("GET", ts.URL, nil) 591 trace := &httptrace.ClientTrace{ 592 GetConn: func(hostPort string) { 593 preDial <- struct{}{} 594 }, 595 } 596 req = req.WithContext(httptrace.WithClientTrace(req.Context(), trace)) 597 resp, err := tr.RoundTrip(req) 598 if err != nil { 599 t.Errorf("unexpected error for request %s: %v", reqId, err) 600 } 601 _, err = io.ReadAll(resp.Body) 602 if err != nil { 603 t.Errorf("unexpected error for request %s: %v", reqId, err) 604 } 605 reqComplete <- struct{}{} 606 } 607 // get req1 to dial-in-progress 608 go doReq("req1") 609 <-preDial 610 <-dialStarted 611 612 // get req2 to waiting on conns per host to go down below max 613 go doReq("req2") 614 <-preDial 615 select { 616 case <-dialStarted: 617 t.Error("req2 dial started while req1 dial in progress") 618 return 619 default: 620 } 621 622 // let req1 complete 623 stallDial <- struct{}{} 624 <-reqComplete 625 626 // let req2 complete 627 <-dialStarted 628 stallDial <- struct{}{} 629 <-reqComplete 630 } 631 632 func TestTransportMaxConnsPerHost(t *testing.T) { 633 run(t, testTransportMaxConnsPerHost, []testMode{http1Mode, https1Mode, http2Mode}) 634 } 635 func testTransportMaxConnsPerHost(t *testing.T, mode testMode) { 636 CondSkipHTTP2(t) 637 638 h := HandlerFunc(func(w ResponseWriter, r *Request) { 639 _, err := w.Write([]byte("foo")) 640 if err != nil { 641 t.Fatalf("Write: %v", err) 642 } 643 }) 644 645 ts := newClientServerTest(t, mode, h).ts 646 c := ts.Client() 647 tr := c.Transport.(*Transport) 648 tr.MaxConnsPerHost = 1 649 650 mu := sync.Mutex{} 651 var conns []net.Conn 
652 var dialCnt, gotConnCnt, tlsHandshakeCnt int32 653 tr.Dial = func(network, addr string) (net.Conn, error) { 654 atomic.AddInt32(&dialCnt, 1) 655 c, err := net.Dial(network, addr) 656 mu.Lock() 657 defer mu.Unlock() 658 conns = append(conns, c) 659 return c, err 660 } 661 662 doReq := func() { 663 trace := &httptrace.ClientTrace{ 664 GotConn: func(connInfo httptrace.GotConnInfo) { 665 if !connInfo.Reused { 666 atomic.AddInt32(&gotConnCnt, 1) 667 } 668 }, 669 TLSHandshakeStart: func() { 670 atomic.AddInt32(&tlsHandshakeCnt, 1) 671 }, 672 } 673 req, _ := NewRequest("GET", ts.URL, nil) 674 req = req.WithContext(httptrace.WithClientTrace(req.Context(), trace)) 675 676 resp, err := c.Do(req) 677 if err != nil { 678 t.Fatalf("request failed: %v", err) 679 } 680 defer resp.Body.Close() 681 _, err = io.ReadAll(resp.Body) 682 if err != nil { 683 t.Fatalf("read body failed: %v", err) 684 } 685 } 686 687 wg := sync.WaitGroup{} 688 for i := 0; i < 10; i++ { 689 wg.Add(1) 690 go func() { 691 defer wg.Done() 692 doReq() 693 }() 694 } 695 wg.Wait() 696 697 expected := int32(tr.MaxConnsPerHost) 698 if dialCnt != expected { 699 t.Errorf("round 1: too many dials: %d != %d", dialCnt, expected) 700 } 701 if gotConnCnt != expected { 702 t.Errorf("round 1: too many get connections: %d != %d", gotConnCnt, expected) 703 } 704 if ts.TLS != nil && tlsHandshakeCnt != expected { 705 t.Errorf("round 1: too many tls handshakes: %d != %d", tlsHandshakeCnt, expected) 706 } 707 708 if t.Failed() { 709 t.FailNow() 710 } 711 712 mu.Lock() 713 for _, c := range conns { 714 c.Close() 715 } 716 conns = nil 717 mu.Unlock() 718 tr.CloseIdleConnections() 719 720 doReq() 721 expected++ 722 if dialCnt != expected { 723 t.Errorf("round 2: too many dials: %d", dialCnt) 724 } 725 if gotConnCnt != expected { 726 t.Errorf("round 2: too many get connections: %d != %d", gotConnCnt, expected) 727 } 728 if ts.TLS != nil && tlsHandshakeCnt != expected { 729 t.Errorf("round 2: too many tls handshakes: %d != %d", 
tlsHandshakeCnt, expected) 730 } 731 } 732 733 func TestTransportMaxConnsPerHostDialCancellation(t *testing.T) { 734 run(t, testTransportMaxConnsPerHostDialCancellation, 735 testNotParallel, // because test uses SetPendingDialHooks 736 []testMode{http1Mode, https1Mode, http2Mode}, 737 ) 738 } 739 740 func testTransportMaxConnsPerHostDialCancellation(t *testing.T, mode testMode) { 741 CondSkipHTTP2(t) 742 743 h := HandlerFunc(func(w ResponseWriter, r *Request) { 744 _, err := w.Write([]byte("foo")) 745 if err != nil { 746 t.Fatalf("Write: %v", err) 747 } 748 }) 749 750 cst := newClientServerTest(t, mode, h) 751 defer cst.close() 752 ts := cst.ts 753 c := ts.Client() 754 tr := c.Transport.(*Transport) 755 tr.MaxConnsPerHost = 1 756 757 // This request is cancelled when dial is queued, which preempts dialing. 758 ctx, cancel := context.WithCancel(context.Background()) 759 defer cancel() 760 SetPendingDialHooks(cancel, nil) 761 defer SetPendingDialHooks(nil, nil) 762 763 req, _ := NewRequestWithContext(ctx, "GET", ts.URL, nil) 764 _, err := c.Do(req) 765 if !errors.Is(err, context.Canceled) { 766 t.Errorf("expected error %v, got %v", context.Canceled, err) 767 } 768 769 // This request should succeed. 
770 SetPendingDialHooks(nil, nil) 771 req, _ = NewRequest("GET", ts.URL, nil) 772 resp, err := c.Do(req) 773 if err != nil { 774 t.Fatalf("request failed: %v", err) 775 } 776 defer resp.Body.Close() 777 _, err = io.ReadAll(resp.Body) 778 if err != nil { 779 t.Fatalf("read body failed: %v", err) 780 } 781 } 782 783 func TestTransportRemovesDeadIdleConnections(t *testing.T) { 784 run(t, testTransportRemovesDeadIdleConnections, []testMode{http1Mode}) 785 } 786 func testTransportRemovesDeadIdleConnections(t *testing.T, mode testMode) { 787 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { 788 io.WriteString(w, r.RemoteAddr) 789 })).ts 790 791 c := ts.Client() 792 tr := c.Transport.(*Transport) 793 794 doReq := func(name string) { 795 // Do a POST instead of a GET to prevent the Transport's 796 // idempotent request retry logic from kicking in... 797 res, err := c.Post(ts.URL, "", nil) 798 if err != nil { 799 t.Fatalf("%s: %v", name, err) 800 } 801 if res.StatusCode != 200 { 802 t.Fatalf("%s: %v", name, res.Status) 803 } 804 defer res.Body.Close() 805 slurp, err := io.ReadAll(res.Body) 806 if err != nil { 807 t.Fatalf("%s: %v", name, err) 808 } 809 t.Logf("%s: ok (%q)", name, slurp) 810 } 811 812 doReq("first") 813 keys1 := tr.IdleConnKeysForTesting() 814 815 ts.CloseClientConnections() 816 817 var keys2 []string 818 waitCondition(t, 10*time.Millisecond, func(d time.Duration) bool { 819 keys2 = tr.IdleConnKeysForTesting() 820 if len(keys2) != 0 { 821 if d > 0 { 822 t.Logf("Transport hasn't noticed idle connection's death in %v.\nbefore: %q\n after: %q\n", d, keys1, keys2) 823 } 824 return false 825 } 826 return true 827 }) 828 829 doReq("second") 830 } 831 832 // Test that the Transport notices when a server hangs up on its 833 // unexpectedly (a keep-alive connection is closed). 
834 func TestTransportServerClosingUnexpectedly(t *testing.T) { 835 run(t, testTransportServerClosingUnexpectedly, []testMode{http1Mode}) 836 } 837 func testTransportServerClosingUnexpectedly(t *testing.T, mode testMode) { 838 ts := newClientServerTest(t, mode, hostPortHandler).ts 839 c := ts.Client() 840 841 fetch := func(n, retries int) string { 842 condFatalf := func(format string, arg ...any) { 843 if retries <= 0 { 844 t.Fatalf(format, arg...) 845 } 846 t.Logf("retrying shortly after expected error: "+format, arg...) 847 time.Sleep(time.Second / time.Duration(retries)) 848 } 849 for retries >= 0 { 850 retries-- 851 res, err := c.Get(ts.URL) 852 if err != nil { 853 condFatalf("error in req #%d, GET: %v", n, err) 854 continue 855 } 856 body, err := io.ReadAll(res.Body) 857 if err != nil { 858 condFatalf("error in req #%d, ReadAll: %v", n, err) 859 continue 860 } 861 res.Body.Close() 862 return string(body) 863 } 864 panic("unreachable") 865 } 866 867 body1 := fetch(1, 0) 868 body2 := fetch(2, 0) 869 870 // Close all the idle connections in a way that's similar to 871 // the server hanging up on us. We don't use 872 // httptest.Server.CloseClientConnections because it's 873 // best-effort and stops blocking after 5 seconds. On a loaded 874 // machine running many tests concurrently it's possible for 875 // that method to be async and cause the body3 fetch below to 876 // run on an old connection. This function is synchronous. 877 ExportCloseTransportConnsAbruptly(c.Transport.(*Transport)) 878 879 body3 := fetch(3, 5) 880 881 if body1 != body2 { 882 t.Errorf("expected body1 and body2 to be equal") 883 } 884 if body2 == body3 { 885 t.Errorf("expected body2 and body3 to be different") 886 } 887 } 888 889 // Test for https://golang.org/issue/2616 (appropriate issue number) 890 // This fails pretty reliably with GOMAXPROCS=100 or something high. 
891 func TestStressSurpriseServerCloses(t *testing.T) { 892 run(t, testStressSurpriseServerCloses, []testMode{http1Mode}) 893 } 894 func testStressSurpriseServerCloses(t *testing.T, mode testMode) { 895 if testing.Short() { 896 t.Skip("skipping test in short mode") 897 } 898 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { 899 w.Header().Set("Content-Length", "5") 900 w.Header().Set("Content-Type", "text/plain") 901 w.Write([]byte("Hello")) 902 w.(Flusher).Flush() 903 conn, buf, _ := w.(Hijacker).Hijack() 904 buf.Flush() 905 conn.Close() 906 })).ts 907 c := ts.Client() 908 909 // Do a bunch of traffic from different goroutines. Send to activityc 910 // after each request completes, regardless of whether it failed. 911 // If these are too high, OS X exhausts its ephemeral ports 912 // and hangs waiting for them to transition TCP states. That's 913 // not what we want to test. TODO(bradfitz): use an io.Pipe 914 // dialer for this test instead? 915 const ( 916 numClients = 20 917 reqsPerClient = 25 918 ) 919 var wg sync.WaitGroup 920 wg.Add(numClients * reqsPerClient) 921 for i := 0; i < numClients; i++ { 922 go func() { 923 for i := 0; i < reqsPerClient; i++ { 924 res, err := c.Get(ts.URL) 925 if err == nil { 926 // We expect errors since the server is 927 // hanging up on us after telling us to 928 // send more requests, so we don't 929 // actually care what the error is. 930 // But we want to close the body in cases 931 // where we won the race. 932 res.Body.Close() 933 } 934 wg.Done() 935 } 936 }() 937 } 938 939 // Make sure all the request come back, one way or another. 
940 wg.Wait() 941 } 942 943 // TestTransportHeadResponses verifies that we deal with Content-Lengths 944 // with no bodies properly 945 func TestTransportHeadResponses(t *testing.T) { run(t, testTransportHeadResponses) } 946 func testTransportHeadResponses(t *testing.T, mode testMode) { 947 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { 948 if r.Method != "HEAD" { 949 panic("expected HEAD; got " + r.Method) 950 } 951 w.Header().Set("Content-Length", "123") 952 w.WriteHeader(200) 953 })).ts 954 c := ts.Client() 955 956 for i := 0; i < 2; i++ { 957 res, err := c.Head(ts.URL) 958 if err != nil { 959 t.Errorf("error on loop %d: %v", i, err) 960 continue 961 } 962 if e, g := "123", res.Header.Get("Content-Length"); e != g { 963 t.Errorf("loop %d: expected Content-Length header of %q, got %q", i, e, g) 964 } 965 if e, g := int64(123), res.ContentLength; e != g { 966 t.Errorf("loop %d: expected res.ContentLength of %v, got %v", i, e, g) 967 } 968 if all, err := io.ReadAll(res.Body); err != nil { 969 t.Errorf("loop %d: Body ReadAll: %v", i, err) 970 } else if len(all) != 0 { 971 t.Errorf("Bogus body %q", all) 972 } 973 } 974 } 975 976 // TestTransportHeadChunkedResponse verifies that we ignore chunked transfer-encoding 977 // on responses to HEAD requests. 
978 func TestTransportHeadChunkedResponse(t *testing.T) { 979 run(t, testTransportHeadChunkedResponse, []testMode{http1Mode}, testNotParallel) 980 } 981 func testTransportHeadChunkedResponse(t *testing.T, mode testMode) { 982 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { 983 if r.Method != "HEAD" { 984 panic("expected HEAD; got " + r.Method) 985 } 986 w.Header().Set("Transfer-Encoding", "chunked") // client should ignore 987 w.Header().Set("x-client-ipport", r.RemoteAddr) 988 w.WriteHeader(200) 989 })).ts 990 c := ts.Client() 991 992 // Ensure that we wait for the readLoop to complete before 993 // calling Head again 994 didRead := make(chan bool) 995 SetReadLoopBeforeNextReadHook(func() { didRead <- true }) 996 defer SetReadLoopBeforeNextReadHook(nil) 997 998 res1, err := c.Head(ts.URL) 999 <-didRead 1000 1001 if err != nil { 1002 t.Fatalf("request 1 error: %v", err) 1003 } 1004 1005 res2, err := c.Head(ts.URL) 1006 <-didRead 1007 1008 if err != nil { 1009 t.Fatalf("request 2 error: %v", err) 1010 } 1011 if v1, v2 := res1.Header.Get("x-client-ipport"), res2.Header.Get("x-client-ipport"); v1 != v2 { 1012 t.Errorf("ip/ports differed between head requests: %q vs %q", v1, v2) 1013 } 1014 } 1015 1016 var roundTripTests = []struct { 1017 accept string 1018 expectAccept string 1019 compressed bool 1020 }{ 1021 // Requests with no accept-encoding header use transparent compression 1022 {"", "gzip", false}, 1023 // Requests with other accept-encoding should pass through unmodified 1024 {"foo", "foo", false}, 1025 // Requests with accept-encoding == gzip should be passed through 1026 {"gzip", "gzip", true}, 1027 } 1028 1029 // Test that the modification made to the Request by the RoundTripper is cleaned up 1030 func TestRoundTripGzip(t *testing.T) { run(t, testRoundTripGzip) } 1031 func testRoundTripGzip(t *testing.T, mode testMode) { 1032 const responseBody = "test response body" 1033 ts := newClientServerTest(t, mode, 
HandlerFunc(func(rw ResponseWriter, req *Request) { 1034 accept := req.Header.Get("Accept-Encoding") 1035 if expect := req.FormValue("expect_accept"); accept != expect { 1036 t.Errorf("in handler, test %v: Accept-Encoding = %q, want %q", 1037 req.FormValue("testnum"), accept, expect) 1038 } 1039 if accept == "gzip" { 1040 rw.Header().Set("Content-Encoding", "gzip") 1041 gz := gzip.NewWriter(rw) 1042 gz.Write([]byte(responseBody)) 1043 gz.Close() 1044 } else { 1045 rw.Header().Set("Content-Encoding", accept) 1046 rw.Write([]byte(responseBody)) 1047 } 1048 })).ts 1049 tr := ts.Client().Transport.(*Transport) 1050 1051 for i, test := range roundTripTests { 1052 // Test basic request (no accept-encoding) 1053 req, _ := NewRequest("GET", fmt.Sprintf("%s/?testnum=%d&expect_accept=%s", ts.URL, i, test.expectAccept), nil) 1054 if test.accept != "" { 1055 req.Header.Set("Accept-Encoding", test.accept) 1056 } 1057 res, err := tr.RoundTrip(req) 1058 if err != nil { 1059 t.Errorf("%d. RoundTrip: %v", i, err) 1060 continue 1061 } 1062 var body []byte 1063 if test.compressed { 1064 var r *gzip.Reader 1065 r, err = gzip.NewReader(res.Body) 1066 if err != nil { 1067 t.Errorf("%d. gzip NewReader: %v", i, err) 1068 continue 1069 } 1070 body, err = io.ReadAll(r) 1071 res.Body.Close() 1072 } else { 1073 body, err = io.ReadAll(res.Body) 1074 } 1075 if err != nil { 1076 t.Errorf("%d. Error: %q", i, err) 1077 continue 1078 } 1079 if g, e := string(body), responseBody; g != e { 1080 t.Errorf("%d. body = %q; want %q", i, g, e) 1081 } 1082 if g, e := req.Header.Get("Accept-Encoding"), test.accept; g != e { 1083 t.Errorf("%d. Accept-Encoding = %q; want %q (it was mutated, in violation of RoundTrip contract)", i, g, e) 1084 } 1085 if g, e := res.Header.Get("Content-Encoding"), test.accept; g != e { 1086 t.Errorf("%d. 
Content-Encoding = %q; want %q", i, g, e) 1087 } 1088 } 1089 1090 } 1091 1092 func TestTransportGzip(t *testing.T) { run(t, testTransportGzip) } 1093 func testTransportGzip(t *testing.T, mode testMode) { 1094 if mode == http2Mode { 1095 t.Skip("https://go.dev/issue/56020") 1096 } 1097 const testString = "The test string aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa" 1098 const nRandBytes = 1024 * 1024 1099 ts := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) { 1100 if req.Method == "HEAD" { 1101 if g := req.Header.Get("Accept-Encoding"); g != "" { 1102 t.Errorf("HEAD request sent with Accept-Encoding of %q; want none", g) 1103 } 1104 return 1105 } 1106 if g, e := req.Header.Get("Accept-Encoding"), "gzip"; g != e { 1107 t.Errorf("Accept-Encoding = %q, want %q", g, e) 1108 } 1109 rw.Header().Set("Content-Encoding", "gzip") 1110 1111 var w io.Writer = rw 1112 var buf bytes.Buffer 1113 if req.FormValue("chunked") == "0" { 1114 w = &buf 1115 defer io.Copy(rw, &buf) 1116 defer func() { 1117 rw.Header().Set("Content-Length", strconv.Itoa(buf.Len())) 1118 }() 1119 } 1120 gz := gzip.NewWriter(w) 1121 gz.Write([]byte(testString)) 1122 if req.FormValue("body") == "large" { 1123 io.CopyN(gz, rand.Reader, nRandBytes) 1124 } 1125 gz.Close() 1126 })).ts 1127 c := ts.Client() 1128 1129 for _, chunked := range []string{"1", "0"} { 1130 // First fetch something large, but only read some of it. 
1131 res, err := c.Get(ts.URL + "/?body=large&chunked=" + chunked) 1132 if err != nil { 1133 t.Fatalf("large get: %v", err) 1134 } 1135 buf := make([]byte, len(testString)) 1136 n, err := io.ReadFull(res.Body, buf) 1137 if err != nil { 1138 t.Fatalf("partial read of large response: size=%d, %v", n, err) 1139 } 1140 if e, g := testString, string(buf); e != g { 1141 t.Errorf("partial read got %q, expected %q", g, e) 1142 } 1143 res.Body.Close() 1144 // Read on the body, even though it's closed 1145 n, err = res.Body.Read(buf) 1146 if n != 0 || err == nil { 1147 t.Errorf("expected error post-closed large Read; got = %d, %v", n, err) 1148 } 1149 1150 // Then something small. 1151 res, err = c.Get(ts.URL + "/?chunked=" + chunked) 1152 if err != nil { 1153 t.Fatal(err) 1154 } 1155 body, err := io.ReadAll(res.Body) 1156 if err != nil { 1157 t.Fatal(err) 1158 } 1159 if g, e := string(body), testString; g != e { 1160 t.Fatalf("body = %q; want %q", g, e) 1161 } 1162 if g, e := res.Header.Get("Content-Encoding"), ""; g != e { 1163 t.Fatalf("Content-Encoding = %q; want %q", g, e) 1164 } 1165 1166 // Read on the body after it's been fully read: 1167 n, err = res.Body.Read(buf) 1168 if n != 0 || err == nil { 1169 t.Errorf("expected Read error after exhausted reads; got %d, %v", n, err) 1170 } 1171 res.Body.Close() 1172 n, err = res.Body.Read(buf) 1173 if n != 0 || err == nil { 1174 t.Errorf("expected Read error after Close; got %d, %v", n, err) 1175 } 1176 } 1177 1178 // And a HEAD request too, because they're always weird. 1179 res, err := c.Head(ts.URL) 1180 if err != nil { 1181 t.Fatalf("Head: %v", err) 1182 } 1183 if res.StatusCode != 200 { 1184 t.Errorf("Head status=%d; want=200", res.StatusCode) 1185 } 1186 } 1187 1188 // If a request has Expect:100-continue header, the request blocks sending body until the first response. 1189 // Premature consumption of the request body should not be occurred. 
1190 func TestTransportExpect100Continue(t *testing.T) { 1191 run(t, testTransportExpect100Continue, []testMode{http1Mode}) 1192 } 1193 func testTransportExpect100Continue(t *testing.T, mode testMode) { 1194 ts := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) { 1195 switch req.URL.Path { 1196 case "/100": 1197 // This endpoint implicitly responds 100 Continue and reads body. 1198 if _, err := io.Copy(io.Discard, req.Body); err != nil { 1199 t.Error("Failed to read Body", err) 1200 } 1201 rw.WriteHeader(StatusOK) 1202 case "/200": 1203 // Go 1.5 adds Connection: close header if the client expect 1204 // continue but not entire request body is consumed. 1205 rw.WriteHeader(StatusOK) 1206 case "/500": 1207 rw.WriteHeader(StatusInternalServerError) 1208 case "/keepalive": 1209 // This hijacked endpoint responds error without Connection:close. 1210 _, bufrw, err := rw.(Hijacker).Hijack() 1211 if err != nil { 1212 log.Fatal(err) 1213 } 1214 bufrw.WriteString("HTTP/1.1 500 Internal Server Error\r\n") 1215 bufrw.WriteString("Content-Length: 0\r\n\r\n") 1216 bufrw.Flush() 1217 case "/timeout": 1218 // This endpoint tries to read body without 100 (Continue) response. 1219 // After ExpectContinueTimeout, the reading will be started. 1220 conn, bufrw, err := rw.(Hijacker).Hijack() 1221 if err != nil { 1222 log.Fatal(err) 1223 } 1224 if _, err := io.CopyN(io.Discard, bufrw, req.ContentLength); err != nil { 1225 t.Error("Failed to read Body", err) 1226 } 1227 bufrw.WriteString("HTTP/1.1 200 OK\r\n\r\n") 1228 bufrw.Flush() 1229 conn.Close() 1230 } 1231 1232 })).ts 1233 1234 tests := []struct { 1235 path string 1236 body []byte 1237 sent int 1238 status int 1239 }{ 1240 {path: "/100", body: []byte("hello"), sent: 5, status: 200}, // Got 100 followed by 200, entire body is sent. 1241 {path: "/200", body: []byte("hello"), sent: 0, status: 200}, // Got 200 without 100. body isn't sent. 
1242 {path: "/500", body: []byte("hello"), sent: 0, status: 500}, // Got 500 without 100. body isn't sent. 1243 {path: "/keepalive", body: []byte("hello"), sent: 0, status: 500}, // Although without Connection:close, body isn't sent. 1244 {path: "/timeout", body: []byte("hello"), sent: 5, status: 200}, // Timeout exceeded and entire body is sent. 1245 } 1246 1247 c := ts.Client() 1248 for i, v := range tests { 1249 tr := &Transport{ 1250 ExpectContinueTimeout: 2 * time.Second, 1251 } 1252 defer tr.CloseIdleConnections() 1253 c.Transport = tr 1254 body := bytes.NewReader(v.body) 1255 req, err := NewRequest("PUT", ts.URL+v.path, body) 1256 if err != nil { 1257 t.Fatal(err) 1258 } 1259 req.Header.Set("Expect", "100-continue") 1260 req.ContentLength = int64(len(v.body)) 1261 1262 resp, err := c.Do(req) 1263 if err != nil { 1264 t.Fatal(err) 1265 } 1266 resp.Body.Close() 1267 1268 sent := len(v.body) - body.Len() 1269 if v.status != resp.StatusCode { 1270 t.Errorf("test %d: status code should be %d but got %d. (%s)", i, v.status, resp.StatusCode, v.path) 1271 } 1272 if v.sent != sent { 1273 t.Errorf("test %d: sent body should be %d but sent %d. 
(%s)", i, v.sent, sent, v.path) 1274 } 1275 } 1276 } 1277 1278 func TestSOCKS5Proxy(t *testing.T) { 1279 run(t, testSOCKS5Proxy, []testMode{http1Mode, https1Mode, http2Mode}) 1280 } 1281 func testSOCKS5Proxy(t *testing.T, mode testMode) { 1282 ch := make(chan string, 1) 1283 l := newLocalListener(t) 1284 defer l.Close() 1285 defer close(ch) 1286 proxy := func(t *testing.T) { 1287 s, err := l.Accept() 1288 if err != nil { 1289 t.Errorf("socks5 proxy Accept(): %v", err) 1290 return 1291 } 1292 defer s.Close() 1293 var buf [22]byte 1294 if _, err := io.ReadFull(s, buf[:3]); err != nil { 1295 t.Errorf("socks5 proxy initial read: %v", err) 1296 return 1297 } 1298 if want := []byte{5, 1, 0}; !bytes.Equal(buf[:3], want) { 1299 t.Errorf("socks5 proxy initial read: got %v, want %v", buf[:3], want) 1300 return 1301 } 1302 if _, err := s.Write([]byte{5, 0}); err != nil { 1303 t.Errorf("socks5 proxy initial write: %v", err) 1304 return 1305 } 1306 if _, err := io.ReadFull(s, buf[:4]); err != nil { 1307 t.Errorf("socks5 proxy second read: %v", err) 1308 return 1309 } 1310 if want := []byte{5, 1, 0}; !bytes.Equal(buf[:3], want) { 1311 t.Errorf("socks5 proxy second read: got %v, want %v", buf[:3], want) 1312 return 1313 } 1314 var ipLen int 1315 switch buf[3] { 1316 case 1: 1317 ipLen = net.IPv4len 1318 case 4: 1319 ipLen = net.IPv6len 1320 default: 1321 t.Errorf("socks5 proxy second read: unexpected address type %v", buf[4]) 1322 return 1323 } 1324 if _, err := io.ReadFull(s, buf[4:ipLen+6]); err != nil { 1325 t.Errorf("socks5 proxy address read: %v", err) 1326 return 1327 } 1328 ip := net.IP(buf[4 : ipLen+4]) 1329 port := binary.BigEndian.Uint16(buf[ipLen+4 : ipLen+6]) 1330 copy(buf[:3], []byte{5, 0, 0}) 1331 if _, err := s.Write(buf[:ipLen+6]); err != nil { 1332 t.Errorf("socks5 proxy connect write: %v", err) 1333 return 1334 } 1335 ch <- fmt.Sprintf("proxy for %s:%d", ip, port) 1336 1337 // Implement proxying. 
1338 targetHost := net.JoinHostPort(ip.String(), strconv.Itoa(int(port))) 1339 targetConn, err := net.Dial("tcp", targetHost) 1340 if err != nil { 1341 t.Errorf("net.Dial failed") 1342 return 1343 } 1344 go io.Copy(targetConn, s) 1345 io.Copy(s, targetConn) // Wait for the client to close the socket. 1346 targetConn.Close() 1347 } 1348 1349 pu, err := url.Parse("socks5://" + l.Addr().String()) 1350 if err != nil { 1351 t.Fatal(err) 1352 } 1353 1354 sentinelHeader := "X-Sentinel" 1355 sentinelValue := "12345" 1356 h := HandlerFunc(func(w ResponseWriter, r *Request) { 1357 w.Header().Set(sentinelHeader, sentinelValue) 1358 }) 1359 for _, useTLS := range []bool{false, true} { 1360 t.Run(fmt.Sprintf("useTLS=%v", useTLS), func(t *testing.T) { 1361 ts := newClientServerTest(t, mode, h).ts 1362 go proxy(t) 1363 c := ts.Client() 1364 c.Transport.(*Transport).Proxy = ProxyURL(pu) 1365 r, err := c.Head(ts.URL) 1366 if err != nil { 1367 t.Fatal(err) 1368 } 1369 if r.Header.Get(sentinelHeader) != sentinelValue { 1370 t.Errorf("Failed to retrieve sentinel value") 1371 } 1372 got := <-ch 1373 ts.Close() 1374 tsu, err := url.Parse(ts.URL) 1375 if err != nil { 1376 t.Fatal(err) 1377 } 1378 want := "proxy for " + tsu.Host 1379 if got != want { 1380 t.Errorf("got %q, want %q", got, want) 1381 } 1382 }) 1383 } 1384 } 1385 1386 func TestTransportProxy(t *testing.T) { 1387 defer afterTest(t) 1388 testCases := []struct{ siteMode, proxyMode testMode }{ 1389 {http1Mode, http1Mode}, 1390 {http1Mode, https1Mode}, 1391 {https1Mode, http1Mode}, 1392 {https1Mode, https1Mode}, 1393 } 1394 for _, testCase := range testCases { 1395 siteMode := testCase.siteMode 1396 proxyMode := testCase.proxyMode 1397 t.Run(fmt.Sprintf("site=%v/proxy=%v", siteMode, proxyMode), func(t *testing.T) { 1398 siteCh := make(chan *Request, 1) 1399 h1 := HandlerFunc(func(w ResponseWriter, r *Request) { 1400 siteCh <- r 1401 }) 1402 proxyCh := make(chan *Request, 1) 1403 h2 := HandlerFunc(func(w ResponseWriter, r 
*Request) { 1404 proxyCh <- r 1405 // Implement an entire CONNECT proxy 1406 if r.Method == "CONNECT" { 1407 hijacker, ok := w.(Hijacker) 1408 if !ok { 1409 t.Errorf("hijack not allowed") 1410 return 1411 } 1412 clientConn, _, err := hijacker.Hijack() 1413 if err != nil { 1414 t.Errorf("hijacking failed") 1415 return 1416 } 1417 res := &Response{ 1418 StatusCode: StatusOK, 1419 Proto: "HTTP/1.1", 1420 ProtoMajor: 1, 1421 ProtoMinor: 1, 1422 Header: make(Header), 1423 } 1424 1425 targetConn, err := net.Dial("tcp", r.URL.Host) 1426 if err != nil { 1427 t.Errorf("net.Dial(%q) failed: %v", r.URL.Host, err) 1428 return 1429 } 1430 1431 if err := res.Write(clientConn); err != nil { 1432 t.Errorf("Writing 200 OK failed: %v", err) 1433 return 1434 } 1435 1436 go io.Copy(targetConn, clientConn) 1437 go func() { 1438 io.Copy(clientConn, targetConn) 1439 targetConn.Close() 1440 }() 1441 } 1442 }) 1443 ts := newClientServerTest(t, siteMode, h1).ts 1444 proxy := newClientServerTest(t, proxyMode, h2).ts 1445 1446 pu, err := url.Parse(proxy.URL) 1447 if err != nil { 1448 t.Fatal(err) 1449 } 1450 1451 // If neither server is HTTPS or both are, then c may be derived from either. 1452 // If only one server is HTTPS, c must be derived from that server in order 1453 // to ensure that it is configured to use the fake root CA from testcert.go. 
1454 c := proxy.Client() 1455 if siteMode == https1Mode { 1456 c = ts.Client() 1457 } 1458 1459 c.Transport.(*Transport).Proxy = ProxyURL(pu) 1460 if _, err := c.Head(ts.URL); err != nil { 1461 t.Error(err) 1462 } 1463 got := <-proxyCh 1464 c.Transport.(*Transport).CloseIdleConnections() 1465 ts.Close() 1466 proxy.Close() 1467 if siteMode == https1Mode { 1468 // First message should be a CONNECT, asking for a socket to the real server, 1469 if got.Method != "CONNECT" { 1470 t.Errorf("Wrong method for secure proxying: %q", got.Method) 1471 } 1472 gotHost := got.URL.Host 1473 pu, err := url.Parse(ts.URL) 1474 if err != nil { 1475 t.Fatal("Invalid site URL") 1476 } 1477 if wantHost := pu.Host; gotHost != wantHost { 1478 t.Errorf("Got CONNECT host %q, want %q", gotHost, wantHost) 1479 } 1480 1481 // The next message on the channel should be from the site's server. 1482 next := <-siteCh 1483 if next.Method != "HEAD" { 1484 t.Errorf("Wrong method at destination: %s", next.Method) 1485 } 1486 if nextURL := next.URL.String(); nextURL != "/" { 1487 t.Errorf("Wrong URL at destination: %s", nextURL) 1488 } 1489 } else { 1490 if got.Method != "HEAD" { 1491 t.Errorf("Wrong method for destination: %q", got.Method) 1492 } 1493 gotURL := got.URL.String() 1494 wantURL := ts.URL + "/" 1495 if gotURL != wantURL { 1496 t.Errorf("Got URL %q, want %q", gotURL, wantURL) 1497 } 1498 } 1499 }) 1500 } 1501 } 1502 1503 func TestOnProxyConnectResponse(t *testing.T) { 1504 1505 var tcases = []struct { 1506 proxyStatusCode int 1507 err error 1508 }{ 1509 { 1510 StatusOK, 1511 nil, 1512 }, 1513 { 1514 StatusForbidden, 1515 errors.New("403"), 1516 }, 1517 } 1518 for _, tcase := range tcases { 1519 h1 := HandlerFunc(func(w ResponseWriter, r *Request) { 1520 1521 }) 1522 1523 h2 := HandlerFunc(func(w ResponseWriter, r *Request) { 1524 // Implement an entire CONNECT proxy 1525 if r.Method == "CONNECT" { 1526 if tcase.proxyStatusCode != StatusOK { 1527 w.WriteHeader(tcase.proxyStatusCode) 1528 return 
1529 } 1530 hijacker, ok := w.(Hijacker) 1531 if !ok { 1532 t.Errorf("hijack not allowed") 1533 return 1534 } 1535 clientConn, _, err := hijacker.Hijack() 1536 if err != nil { 1537 t.Errorf("hijacking failed") 1538 return 1539 } 1540 res := &Response{ 1541 StatusCode: StatusOK, 1542 Proto: "HTTP/1.1", 1543 ProtoMajor: 1, 1544 ProtoMinor: 1, 1545 Header: make(Header), 1546 } 1547 1548 targetConn, err := net.Dial("tcp", r.URL.Host) 1549 if err != nil { 1550 t.Errorf("net.Dial(%q) failed: %v", r.URL.Host, err) 1551 return 1552 } 1553 1554 if err := res.Write(clientConn); err != nil { 1555 t.Errorf("Writing 200 OK failed: %v", err) 1556 return 1557 } 1558 1559 go io.Copy(targetConn, clientConn) 1560 go func() { 1561 io.Copy(clientConn, targetConn) 1562 targetConn.Close() 1563 }() 1564 } 1565 }) 1566 ts := newClientServerTest(t, https1Mode, h1).ts 1567 proxy := newClientServerTest(t, https1Mode, h2).ts 1568 1569 pu, err := url.Parse(proxy.URL) 1570 if err != nil { 1571 t.Fatal(err) 1572 } 1573 1574 c := proxy.Client() 1575 1576 var ( 1577 dials atomic.Int32 1578 closes atomic.Int32 1579 ) 1580 c.Transport.(*Transport).DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) { 1581 conn, err := net.Dial(network, addr) 1582 if err != nil { 1583 return nil, err 1584 } 1585 dials.Add(1) 1586 return noteCloseConn{ 1587 Conn: conn, 1588 closeFunc: func() { 1589 closes.Add(1) 1590 }, 1591 }, nil 1592 } 1593 1594 c.Transport.(*Transport).Proxy = ProxyURL(pu) 1595 c.Transport.(*Transport).OnProxyConnectResponse = func(ctx context.Context, proxyURL *url.URL, connectReq *Request, connectRes *Response) error { 1596 if proxyURL.String() != pu.String() { 1597 t.Errorf("proxy url got %s, want %s", proxyURL, pu) 1598 } 1599 1600 if "https://"+connectReq.URL.String() != ts.URL { 1601 t.Errorf("connect url got %s, want %s", connectReq.URL, ts.URL) 1602 } 1603 return tcase.err 1604 } 1605 wantCloses := int32(0) 1606 if _, err := c.Head(ts.URL); err != nil { 1607 
wantCloses = 1 1608 if tcase.err != nil && !strings.Contains(err.Error(), tcase.err.Error()) { 1609 t.Errorf("got %v, want %v", err, tcase.err) 1610 } 1611 } else { 1612 if tcase.err != nil { 1613 t.Errorf("got %v, want nil", err) 1614 } 1615 } 1616 if got, want := dials.Load(), int32(1); got != want { 1617 t.Errorf("got %v dials, want %v", got, want) 1618 } 1619 // #64804: If OnProxyConnectResponse returns an error, we should close the conn. 1620 if got, want := closes.Load(), wantCloses; got != want { 1621 t.Errorf("got %v closes, want %v", got, want) 1622 } 1623 } 1624 } 1625 1626 // Issue 28012: verify that the Transport closes its TCP connection to http proxies 1627 // when they're slow to reply to HTTPS CONNECT responses. 1628 func TestTransportProxyHTTPSConnectLeak(t *testing.T) { 1629 setParallel(t) 1630 defer afterTest(t) 1631 1632 ctx, cancel := context.WithCancel(context.Background()) 1633 defer cancel() 1634 1635 ln := newLocalListener(t) 1636 defer ln.Close() 1637 listenerDone := make(chan struct{}) 1638 go func() { 1639 defer close(listenerDone) 1640 c, err := ln.Accept() 1641 if err != nil { 1642 t.Errorf("Accept: %v", err) 1643 return 1644 } 1645 defer c.Close() 1646 // Read the CONNECT request 1647 br := bufio.NewReader(c) 1648 cr, err := ReadRequest(br) 1649 if err != nil { 1650 t.Errorf("proxy server failed to read CONNECT request") 1651 return 1652 } 1653 if cr.Method != "CONNECT" { 1654 t.Errorf("unexpected method %q", cr.Method) 1655 return 1656 } 1657 1658 // Now hang and never write a response; instead, cancel the request and wait 1659 // for the client to close. 1660 // (Prior to Issue 28012 being fixed, we never closed.) 
1661 cancel() 1662 var buf [1]byte 1663 _, err = br.Read(buf[:]) 1664 if err != io.EOF { 1665 t.Errorf("proxy server Read err = %v; want EOF", err) 1666 } 1667 return 1668 }() 1669 1670 c := &Client{ 1671 Transport: &Transport{ 1672 Proxy: func(*Request) (*url.URL, error) { 1673 return url.Parse("http://" + ln.Addr().String()) 1674 }, 1675 }, 1676 } 1677 req, err := NewRequestWithContext(ctx, "GET", "https://golang.fake.tld/", nil) 1678 if err != nil { 1679 t.Fatal(err) 1680 } 1681 _, err = c.Do(req) 1682 if err == nil { 1683 t.Errorf("unexpected Get success") 1684 } 1685 1686 // Wait unconditionally for the listener goroutine to exit: this should never 1687 // hang, so if it does we want a full goroutine dump — and that's exactly what 1688 // the testing package will give us when the test run times out. 1689 <-listenerDone 1690 } 1691 1692 // Issue 16997: test transport dial preserves typed errors 1693 func TestTransportDialPreservesNetOpProxyError(t *testing.T) { 1694 defer afterTest(t) 1695 1696 var errDial = errors.New("some dial error") 1697 1698 tr := &Transport{ 1699 Proxy: func(*Request) (*url.URL, error) { 1700 return url.Parse("http://proxy.fake.tld/") 1701 }, 1702 Dial: func(string, string) (net.Conn, error) { 1703 return nil, errDial 1704 }, 1705 } 1706 defer tr.CloseIdleConnections() 1707 1708 c := &Client{Transport: tr} 1709 req, _ := NewRequest("GET", "http://fake.tld", nil) 1710 res, err := c.Do(req) 1711 if err == nil { 1712 res.Body.Close() 1713 t.Fatal("wanted a non-nil error") 1714 } 1715 1716 uerr, ok := err.(*url.Error) 1717 if !ok { 1718 t.Fatalf("got %T, want *url.Error", err) 1719 } 1720 oe, ok := uerr.Err.(*net.OpError) 1721 if !ok { 1722 t.Fatalf("url.Error.Err = %T; want *net.OpError", uerr.Err) 1723 } 1724 want := &net.OpError{ 1725 Op: "proxyconnect", 1726 Net: "tcp", 1727 Err: errDial, // original error, unwrapped. 
1728 } 1729 if !reflect.DeepEqual(oe, want) { 1730 t.Errorf("Got error %#v; want %#v", oe, want) 1731 } 1732 } 1733 1734 // Issue 36431: calls to RoundTrip should not mutate t.ProxyConnectHeader. 1735 // 1736 // (A bug caused dialConn to instead write the per-request Proxy-Authorization 1737 // header through to the shared Header instance, introducing a data race.) 1738 func TestTransportProxyDialDoesNotMutateProxyConnectHeader(t *testing.T) { 1739 run(t, testTransportProxyDialDoesNotMutateProxyConnectHeader) 1740 } 1741 func testTransportProxyDialDoesNotMutateProxyConnectHeader(t *testing.T, mode testMode) { 1742 proxy := newClientServerTest(t, mode, NotFoundHandler()).ts 1743 defer proxy.Close() 1744 c := proxy.Client() 1745 1746 tr := c.Transport.(*Transport) 1747 tr.Proxy = func(*Request) (*url.URL, error) { 1748 u, _ := url.Parse(proxy.URL) 1749 u.User = url.UserPassword("aladdin", "opensesame") 1750 return u, nil 1751 } 1752 h := tr.ProxyConnectHeader 1753 if h == nil { 1754 h = make(Header) 1755 } 1756 tr.ProxyConnectHeader = h.Clone() 1757 1758 req, err := NewRequest("GET", "https://golang.fake.tld/", nil) 1759 if err != nil { 1760 t.Fatal(err) 1761 } 1762 _, err = c.Do(req) 1763 if err == nil { 1764 t.Errorf("unexpected Get success") 1765 } 1766 1767 if !reflect.DeepEqual(tr.ProxyConnectHeader, h) { 1768 t.Errorf("tr.ProxyConnectHeader = %v; want %v", tr.ProxyConnectHeader, h) 1769 } 1770 } 1771 1772 // TestTransportGzipRecursive sends a gzip quine and checks that the 1773 // client gets the same value back. This is more cute than anything, 1774 // but checks that we don't recurse forever, and checks that 1775 // Content-Encoding is removed. 
1776 func TestTransportGzipRecursive(t *testing.T) { run(t, testTransportGzipRecursive) } 1777 func testTransportGzipRecursive(t *testing.T, mode testMode) { 1778 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { 1779 w.Header().Set("Content-Encoding", "gzip") 1780 w.Write(rgz) 1781 })).ts 1782 1783 c := ts.Client() 1784 res, err := c.Get(ts.URL) 1785 if err != nil { 1786 t.Fatal(err) 1787 } 1788 body, err := io.ReadAll(res.Body) 1789 if err != nil { 1790 t.Fatal(err) 1791 } 1792 if !bytes.Equal(body, rgz) { 1793 t.Fatalf("Incorrect result from recursive gz:\nhave=%x\nwant=%x", 1794 body, rgz) 1795 } 1796 if g, e := res.Header.Get("Content-Encoding"), ""; g != e { 1797 t.Fatalf("Content-Encoding = %q; want %q", g, e) 1798 } 1799 } 1800 1801 // golang.org/issue/7750: request fails when server replies with 1802 // a short gzip body 1803 func TestTransportGzipShort(t *testing.T) { run(t, testTransportGzipShort) } 1804 func testTransportGzipShort(t *testing.T, mode testMode) { 1805 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { 1806 w.Header().Set("Content-Encoding", "gzip") 1807 w.Write([]byte{0x1f, 0x8b}) 1808 })).ts 1809 1810 c := ts.Client() 1811 res, err := c.Get(ts.URL) 1812 if err != nil { 1813 t.Fatal(err) 1814 } 1815 defer res.Body.Close() 1816 _, err = io.ReadAll(res.Body) 1817 if err == nil { 1818 t.Fatal("Expect an error from reading a body.") 1819 } 1820 if err != io.ErrUnexpectedEOF { 1821 t.Errorf("ReadAll error = %v; want io.ErrUnexpectedEOF", err) 1822 } 1823 } 1824 1825 // Wait until number of goroutines is no greater than nmax, or time out. 
1826 func waitNumGoroutine(nmax int) int { 1827 nfinal := runtime.NumGoroutine() 1828 for ntries := 10; ntries > 0 && nfinal > nmax; ntries-- { 1829 time.Sleep(50 * time.Millisecond) 1830 runtime.GC() 1831 nfinal = runtime.NumGoroutine() 1832 } 1833 return nfinal 1834 } 1835 1836 // tests that persistent goroutine connections shut down when no longer desired. 1837 func TestTransportPersistConnLeak(t *testing.T) { 1838 run(t, testTransportPersistConnLeak, testNotParallel) 1839 } 1840 func testTransportPersistConnLeak(t *testing.T, mode testMode) { 1841 if mode == http2Mode { 1842 t.Skip("flaky in HTTP/2") 1843 } 1844 // Not parallel: counts goroutines 1845 1846 const numReq = 25 1847 gotReqCh := make(chan bool, numReq) 1848 unblockCh := make(chan bool, numReq) 1849 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { 1850 gotReqCh <- true 1851 <-unblockCh 1852 w.Header().Set("Content-Length", "0") 1853 w.WriteHeader(204) 1854 })).ts 1855 c := ts.Client() 1856 tr := c.Transport.(*Transport) 1857 1858 n0 := runtime.NumGoroutine() 1859 1860 didReqCh := make(chan bool, numReq) 1861 failed := make(chan bool, numReq) 1862 for i := 0; i < numReq; i++ { 1863 go func() { 1864 res, err := c.Get(ts.URL) 1865 didReqCh <- true 1866 if err != nil { 1867 t.Logf("client fetch error: %v", err) 1868 failed <- true 1869 return 1870 } 1871 res.Body.Close() 1872 }() 1873 } 1874 1875 // Wait for all goroutines to be stuck in the Handler. 1876 for i := 0; i < numReq; i++ { 1877 select { 1878 case <-gotReqCh: 1879 // ok 1880 case <-failed: 1881 // Not great but not what we are testing: 1882 // sometimes an overloaded system will fail to make all the connections. 1883 } 1884 } 1885 1886 nhigh := runtime.NumGoroutine() 1887 1888 // Tell all handlers to unblock and reply. 1889 close(unblockCh) 1890 1891 // Wait for all HTTP clients to be done. 
1892 for i := 0; i < numReq; i++ { 1893 <-didReqCh 1894 } 1895 1896 tr.CloseIdleConnections() 1897 nfinal := waitNumGoroutine(n0 + 5) 1898 1899 growth := nfinal - n0 1900 1901 // We expect 0 or 1 extra goroutine, empirically. Allow up to 5. 1902 // Previously we were leaking one per numReq. 1903 if int(growth) > 5 { 1904 t.Logf("goroutine growth: %d -> %d -> %d (delta: %d)", n0, nhigh, nfinal, growth) 1905 t.Error("too many new goroutines") 1906 } 1907 } 1908 1909 // golang.org/issue/4531: Transport leaks goroutines when 1910 // request.ContentLength is explicitly short 1911 func TestTransportPersistConnLeakShortBody(t *testing.T) { 1912 run(t, testTransportPersistConnLeakShortBody, testNotParallel) 1913 } 1914 func testTransportPersistConnLeakShortBody(t *testing.T, mode testMode) { 1915 if mode == http2Mode { 1916 t.Skip("flaky in HTTP/2") 1917 } 1918 1919 // Not parallel: measures goroutines. 1920 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { 1921 })).ts 1922 c := ts.Client() 1923 tr := c.Transport.(*Transport) 1924 1925 n0 := runtime.NumGoroutine() 1926 body := []byte("Hello") 1927 for i := 0; i < 20; i++ { 1928 req, err := NewRequest("POST", ts.URL, bytes.NewReader(body)) 1929 if err != nil { 1930 t.Fatal(err) 1931 } 1932 req.ContentLength = int64(len(body) - 2) // explicitly short 1933 _, err = c.Do(req) 1934 if err == nil { 1935 t.Fatal("Expect an error from writing too long of a body.") 1936 } 1937 } 1938 nhigh := runtime.NumGoroutine() 1939 tr.CloseIdleConnections() 1940 nfinal := waitNumGoroutine(n0 + 5) 1941 1942 growth := nfinal - n0 1943 1944 // We expect 0 or 1 extra goroutine, empirically. Allow up to 5. 1945 // Previously we were leaking one per numReq. 1946 t.Logf("goroutine growth: %d -> %d -> %d (delta: %d)", n0, nhigh, nfinal, growth) 1947 if int(growth) > 5 { 1948 t.Error("too many new goroutines") 1949 } 1950 } 1951 1952 // A countedConn is a net.Conn that decrements an atomic counter when finalized. 
1953 type countedConn struct { 1954 net.Conn 1955 } 1956 1957 // A countingDialer dials connections and counts the number that remain reachable. 1958 type countingDialer struct { 1959 dialer net.Dialer 1960 mu sync.Mutex 1961 total, live int64 1962 } 1963 1964 func (d *countingDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) { 1965 conn, err := d.dialer.DialContext(ctx, network, address) 1966 if err != nil { 1967 return nil, err 1968 } 1969 1970 counted := new(countedConn) 1971 counted.Conn = conn 1972 1973 d.mu.Lock() 1974 defer d.mu.Unlock() 1975 d.total++ 1976 d.live++ 1977 1978 runtime.SetFinalizer(counted, d.decrement) 1979 return counted, nil 1980 } 1981 1982 func (d *countingDialer) decrement(*countedConn) { 1983 d.mu.Lock() 1984 defer d.mu.Unlock() 1985 d.live-- 1986 } 1987 1988 func (d *countingDialer) Read() (total, live int64) { 1989 d.mu.Lock() 1990 defer d.mu.Unlock() 1991 return d.total, d.live 1992 } 1993 1994 func TestTransportPersistConnLeakNeverIdle(t *testing.T) { 1995 run(t, testTransportPersistConnLeakNeverIdle, []testMode{http1Mode}) 1996 } 1997 func testTransportPersistConnLeakNeverIdle(t *testing.T, mode testMode) { 1998 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { 1999 // Close every connection so that it cannot be kept alive. 
2000 conn, _, err := w.(Hijacker).Hijack() 2001 if err != nil { 2002 t.Errorf("Hijack failed unexpectedly: %v", err) 2003 return 2004 } 2005 conn.Close() 2006 })).ts 2007 2008 var d countingDialer 2009 c := ts.Client() 2010 c.Transport.(*Transport).DialContext = d.DialContext 2011 2012 body := []byte("Hello") 2013 for i := 0; ; i++ { 2014 total, live := d.Read() 2015 if live < total { 2016 break 2017 } 2018 if i >= 1<<12 { 2019 t.Fatalf("Count of live client net.Conns (%d) not lower than total (%d) after %d Do / GC iterations.", live, total, i) 2020 } 2021 2022 req, err := NewRequest("POST", ts.URL, bytes.NewReader(body)) 2023 if err != nil { 2024 t.Fatal(err) 2025 } 2026 _, err = c.Do(req) 2027 if err == nil { 2028 t.Fatal("expected broken connection") 2029 } 2030 2031 runtime.GC() 2032 } 2033 } 2034 2035 type countedContext struct { 2036 context.Context 2037 } 2038 2039 type contextCounter struct { 2040 mu sync.Mutex 2041 live int64 2042 } 2043 2044 func (cc *contextCounter) Track(ctx context.Context) context.Context { 2045 counted := new(countedContext) 2046 counted.Context = ctx 2047 cc.mu.Lock() 2048 defer cc.mu.Unlock() 2049 cc.live++ 2050 runtime.SetFinalizer(counted, cc.decrement) 2051 return counted 2052 } 2053 2054 func (cc *contextCounter) decrement(*countedContext) { 2055 cc.mu.Lock() 2056 defer cc.mu.Unlock() 2057 cc.live-- 2058 } 2059 2060 func (cc *contextCounter) Read() (live int64) { 2061 cc.mu.Lock() 2062 defer cc.mu.Unlock() 2063 return cc.live 2064 } 2065 2066 func TestTransportPersistConnContextLeakMaxConnsPerHost(t *testing.T) { 2067 run(t, testTransportPersistConnContextLeakMaxConnsPerHost) 2068 } 2069 func testTransportPersistConnContextLeakMaxConnsPerHost(t *testing.T, mode testMode) { 2070 if mode == http2Mode { 2071 t.Skip("https://go.dev/issue/56021") 2072 } 2073 2074 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { 2075 runtime.Gosched() 2076 w.WriteHeader(StatusOK) 2077 })).ts 2078 2079 c := 
ts.Client() 2080 c.Transport.(*Transport).MaxConnsPerHost = 1 2081 2082 ctx := context.Background() 2083 body := []byte("Hello") 2084 doPosts := func(cc *contextCounter) { 2085 var wg sync.WaitGroup 2086 for n := 64; n > 0; n-- { 2087 wg.Add(1) 2088 go func() { 2089 defer wg.Done() 2090 2091 ctx := cc.Track(ctx) 2092 req, err := NewRequest("POST", ts.URL, bytes.NewReader(body)) 2093 if err != nil { 2094 t.Error(err) 2095 } 2096 2097 _, err = c.Do(req.WithContext(ctx)) 2098 if err != nil { 2099 t.Errorf("Do failed with error: %v", err) 2100 } 2101 }() 2102 } 2103 wg.Wait() 2104 } 2105 2106 var initialCC contextCounter 2107 doPosts(&initialCC) 2108 2109 // flushCC exists only to put pressure on the GC to finalize the initialCC 2110 // contexts: the flushCC allocations should eventually displace the initialCC 2111 // allocations. 2112 var flushCC contextCounter 2113 for i := 0; ; i++ { 2114 live := initialCC.Read() 2115 if live == 0 { 2116 break 2117 } 2118 if i >= 100 { 2119 t.Fatalf("%d Contexts still not finalized after %d GC cycles.", live, i) 2120 } 2121 doPosts(&flushCC) 2122 runtime.GC() 2123 } 2124 } 2125 2126 // This used to crash; https://golang.org/issue/3266 2127 func TestTransportIdleConnCrash(t *testing.T) { run(t, testTransportIdleConnCrash) } 2128 func testTransportIdleConnCrash(t *testing.T, mode testMode) { 2129 var tr *Transport 2130 2131 unblockCh := make(chan bool, 1) 2132 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { 2133 <-unblockCh 2134 tr.CloseIdleConnections() 2135 })).ts 2136 c := ts.Client() 2137 tr = c.Transport.(*Transport) 2138 2139 didreq := make(chan bool) 2140 go func() { 2141 res, err := c.Get(ts.URL) 2142 if err != nil { 2143 t.Error(err) 2144 } else { 2145 res.Body.Close() // returns idle conn 2146 } 2147 didreq <- true 2148 }() 2149 unblockCh <- true 2150 <-didreq 2151 } 2152 2153 // Test that the transport doesn't close the TCP connection early, 2154 // before the response body has been read. 
// This was a regression
// which sadly lacked a triggering test. The large response body made
// the old race easier to trigger.
func TestIssue3644(t *testing.T) { run(t, testIssue3644) }

// testIssue3644 has the handler mark the response "Connection: close" and
// write numFoos copies of "foo "; the client must read back the complete
// body (numFoos*len("foo ") bytes) without error.
func testIssue3644(t *testing.T, mode testMode) {
	const numFoos = 5000
	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		w.Header().Set("Connection", "close")
		for i := 0; i < numFoos; i++ {
			w.Write([]byte("foo "))
		}
	})).ts
	c := ts.Client()
	res, err := c.Get(ts.URL)
	if err != nil {
		t.Fatal(err)
	}
	defer res.Body.Close()
	bs, err := io.ReadAll(res.Body)
	if err != nil {
		t.Fatal(err)
	}
	if len(bs) != numFoos*len("foo ") {
		t.Errorf("unexpected response length")
	}
}

// Test that a client receives a server's reply, even if the server doesn't read
// the entire request body.
func TestIssue3595(t *testing.T) {
	// Not parallel: modifies the global rstAvoidanceDelay.
	run(t, testIssue3595, testNotParallel)
}

// testIssue3595 POSTs an endless body (neverEnding('a')) to a handler that
// answers 401 without reading it, and checks that the denial message still
// reaches the client. The attempt is run through runTimeSensitiveTest with a
// list of increasing RST avoidance delays; returning a non-nil error from the
// callback presumably asks it to try the next delay (NOTE(review): confirm
// against runTimeSensitiveTest).
func testIssue3595(t *testing.T, mode testMode) {
	runTimeSensitiveTest(t, []time.Duration{
		1 * time.Millisecond,
		5 * time.Millisecond,
		10 * time.Millisecond,
		50 * time.Millisecond,
		100 * time.Millisecond,
		500 * time.Millisecond,
		time.Second,
		5 * time.Second,
	}, func(t *testing.T, timeout time.Duration) error {
		SetRSTAvoidanceDelay(t, timeout)
		t.Logf("set RST avoidance delay to %v", timeout)

		const deniedMsg = "sorry, denied."
		cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
			Error(w, deniedMsg, StatusUnauthorized)
		}))
		// We need to close cst explicitly here so that in-flight server
		// requests don't race with the call to SetRSTAvoidanceDelay for a retry.
		defer cst.close()
		ts := cst.ts
		c := ts.Client()

		res, err := c.Post(ts.URL, "application/octet-stream", neverEnding('a'))
		if err != nil {
			return fmt.Errorf("Post: %v", err)
		}
		got, err := io.ReadAll(res.Body)
		if err != nil {
			return fmt.Errorf("Body ReadAll: %v", err)
		}
		t.Logf("server response:\n%s", got)
		if !strings.Contains(string(got), deniedMsg) {
			// If we got an RST packet too early, we should have seen an error
			// from io.ReadAll, not a silently-truncated body.
			t.Errorf("Known bug: response %q does not contain %q", got, deniedMsg)
		}
		return nil
	})
}

// From https://golang.org/issue/4454 ,
// "client fails to handle requests with no body and chunked encoding"
func TestChunkedNoContent(t *testing.T) { run(t, testChunkedNoContent) }

// testChunkedNoContent sends n GETs to a handler that replies 204 No Content,
// once closing each response body and once leaving it unclosed, and requires
// every GET to succeed.
func testChunkedNoContent(t *testing.T, mode testMode) {
	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		w.WriteHeader(StatusNoContent)
	})).ts

	c := ts.Client()
	for _, closeBody := range []bool{true, false} {
		const n = 4
		for i := 1; i <= n; i++ {
			res, err := c.Get(ts.URL)
			if err != nil {
				t.Errorf("closingBody=%v, req %d/%d: %v", closeBody, i, n, err)
			} else {
				// Bodies are deliberately left open on the closeBody=false pass.
				if closeBody {
					res.Body.Close()
				}
			}
		}
	}
}

// TestTransportConcurrency issues many concurrent echo requests through one
// client (HTTP/1 only) and checks that every response echoes its request.
func TestTransportConcurrency(t *testing.T) {
	run(t, testTransportConcurrency, testNotParallel, []testMode{http1Mode})
}
func testTransportConcurrency(t *testing.T, mode testMode) {
	// Not parallel: uses global test hooks.
	maxProcs, numReqs := 16, 500
	if testing.Short() {
		maxProcs, numReqs = 4, 50
	}
	// Set GOMAXPROCS for the test; the inner call returns the previous
	// setting, which the deferred call restores.
	defer runtime.GOMAXPROCS(runtime.GOMAXPROCS(maxProcs))
	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		fmt.Fprintf(w, "%v", r.FormValue("echo"))
	})).ts

	// One count per request; the dial hooks below add one per pending dial.
	var wg sync.WaitGroup
	wg.Add(numReqs)

	// Due to the Transport's "socket late binding" (see
	// idleConnCh in transport.go), the numReqs HTTP requests
	// below can finish with a dial still outstanding. To keep
	// the leak checker happy, keep track of pending dials and
	// wait for them to finish (and be closed or returned to the
	// idle pool) before we close idle connections.
	SetPendingDialHooks(func() { wg.Add(1) }, wg.Done)
	defer SetPendingDialHooks(nil, nil)

	c := ts.Client()
	reqs := make(chan string)
	// Closing reqs on return ends the workers' range loops.
	defer close(reqs)

	// maxProcs*2 workers pull request ids from reqs; each id is finished
	// (wg.Done) exactly once, whether it succeeded or failed.
	for i := 0; i < maxProcs*2; i++ {
		go func() {
			for req := range reqs {
				res, err := c.Get(ts.URL + "/?echo=" + req)
				if err != nil {
					if runtime.GOOS == "netbsd" && strings.HasSuffix(err.Error(), ": connection reset by peer") {
						// https://go.dev/issue/52168: this test was observed to fail with
						// ECONNRESET errors in Dial on various netbsd builders.
						t.Logf("error on req %s: %v", req, err)
						t.Logf("(see https://go.dev/issue/52168)")
					} else {
						t.Errorf("error on req %s: %v", req, err)
					}
					wg.Done()
					continue
				}
				all, err := io.ReadAll(res.Body)
				if err != nil {
					t.Errorf("read error on req %s: %v", req, err)
				} else if string(all) != req {
					t.Errorf("body of req %s = %q; want %q", req, all, req)
				}
				res.Body.Close()
				wg.Done()
			}
		}()
	}
	for i := 0; i < numReqs; i++ {
		reqs <- fmt.Sprintf("request-%d", i)
	}
	wg.Wait()
}

// TestIssue4191_InfiniteGetTimeout checks that setting a deadline on the
// underlying connection makes reading an endless response body fail.
func TestIssue4191_InfiniteGetTimeout(t *testing.T) { run(t, testIssue4191_InfiniteGetTimeout) }
func testIssue4191_InfiniteGetTimeout(t *testing.T, mode testMode) {
	mux := NewServeMux()
	mux.HandleFunc("/get", func(w ResponseWriter, r *Request) {
		io.Copy(w, neverEnding('a'))
	})
	ts := newClientServerTest(t, mode, mux).ts

	connc := make(chan net.Conn, 1)
	c := ts.Client()
	// Capture the first dialed conn so the test can set a deadline on it;
	// the non-blocking send drops any later conns once the buffer is full.
	c.Transport.(*Transport).Dial = func(n, addr string) (net.Conn, error) {
		conn, err := net.Dial(n, addr)
		if err != nil {
			return nil, err
		}
		select {
		case connc <- conn:
		default:
		}
		return conn, nil
	}

	res, err := c.Get(ts.URL + "/get")
	if err != nil {
		t.Fatalf("Error issuing GET: %v", err)
	}
	defer res.Body.Close()

	conn := <-connc
	conn.SetDeadline(time.Now().Add(1 * time.Millisecond))
	// The body never ends on its own, so only the deadline can stop the
	// copy; a nil error means the deadline was not honored.
	_, err = io.Copy(io.Discard, res.Body)
	if err == nil {
		t.Errorf("Unexpected successful copy")
	}
}

// TestIssue4191_InfiniteGetToPutTimeout streams an endless GET response body
// as the body of a PUT over connections that get a deadline at dial time,
// and expects each PUT to fail rather than hang.
func TestIssue4191_InfiniteGetToPutTimeout(t *testing.T) {
	run(t, testIssue4191_InfiniteGetToPutTimeout, []testMode{http1Mode})
}
func testIssue4191_InfiniteGetToPutTimeout(t *testing.T, mode testMode) {
	const debug = false
	mux := NewServeMux()
	mux.HandleFunc("/get", func(w ResponseWriter, r *Request) {
		io.Copy(w, neverEnding('a'))
	})
	mux.HandleFunc("/put", func(w ResponseWriter, r *Request) {
		defer r.Body.Close()
		io.Copy(io.Discard, r.Body)
	})
	ts := newClientServerTest(t, mode, mux).ts
	timeout := 100 * time.Millisecond

	c := ts.Client()
	// Every dialed conn gets an absolute deadline of timeout from dial time.
	// The closure reads timeout at each dial, so the increase below applies
	// to later connections.
	c.Transport.(*Transport).Dial = func(n, addr string) (net.Conn, error) {
		conn, err := net.Dial(n, addr)
		if err != nil {
			return nil, err
		}
		conn.SetDeadline(time.Now().Add(timeout))
		if debug {
			conn = NewLoggingConn("client", conn)
		}
		return conn, nil
	}

	getFailed := false
	nRuns := 5
	if testing.Short() {
		nRuns = 1
	}
	for i := 0; i < nRuns; i++ {
		if debug {
			println("run", i+1, "of", nRuns)
		}
		sres, err := c.Get(ts.URL + "/get")
		if err != nil {
			if !getFailed {
				// Make the timeout longer, once.
				// i-- makes this run repeat with the longer timeout.
				getFailed = true
				t.Logf("increasing timeout")
				i--
				timeout *= 10
				continue
			}
			t.Errorf("Error issuing GET: %v", err)
			break
		}
		// Use the endless GET body as the PUT body; the deadline must
		// make this request fail.
		req, _ := NewRequest("PUT", ts.URL+"/put", sres.Body)
		_, err = c.Do(req)
		if err == nil {
			sres.Body.Close()
			t.Errorf("Unexpected successful PUT")
			break
		}
		sres.Body.Close()
	}
	if debug {
		println("tests complete; waiting for handlers to finish")
	}
	ts.Close()
}

// TestTransportResponseHeaderTimeout checks that Transport.ResponseHeaderTimeout
// produces a "timeout awaiting response headers" error for a handler that
// does not respond (/slow) while immediate handlers (/fast) still succeed.
func TestTransportResponseHeaderTimeout(t *testing.T) { run(t, testTransportResponseHeaderTimeout) }
func testTransportResponseHeaderTimeout(t *testing.T, mode testMode) {
	if testing.Short() {
		t.Skip("skipping timeout test in -short mode")
	}

	// The whole scenario is rerun with a doubled timeout whenever a /fast
	// request times out, until a pass completes without that or the test
	// has already failed.
	timeout := 2 * time.Millisecond
	retry := true
	for retry && !t.Failed() {
		var srvWG sync.WaitGroup
		inHandler := make(chan bool, 1)
		mux := NewServeMux()
		mux.HandleFunc("/fast", func(w ResponseWriter, r *Request) {
			inHandler <- true
			srvWG.Done()
		})
		mux.HandleFunc("/slow", func(w ResponseWriter, r *Request) {
			inHandler <- true
			// Hold the response until the request context ends,
			// presumably when the client abandons the request after the timeout.
			<-r.Context().Done()
			srvWG.Done()
		})
		ts := newClientServerTest(t, mode, mux).ts

		c := ts.Client()
		c.Transport.(*Transport).ResponseHeaderTimeout = timeout

		retry = false
		// One count per entry in tests; each handler calls srvWG.Done.
		srvWG.Add(3)
		tests := []struct {
			path        string
			wantTimeout bool
		}{
			{path: "/fast"},
			{path: "/slow", wantTimeout: true},
			{path: "/fast"},
		}
		for i, tt := range tests {
			req, _ := NewRequest("GET", ts.URL+tt.path, nil)
			req = req.WithT(t)
			res, err := c.Do(req)
			// Wait until the handler has been entered for this request.
			<-inHandler
			if err != nil {
				uerr, ok := err.(*url.Error)
				if !ok {
					t.Errorf("error is not a url.Error; got: %#v", err)
					continue
				}
				nerr, ok := uerr.Err.(net.Error)
				if !ok {
					t.Errorf("error does not satisfy net.Error interface; got: %#v", err)
					continue
				}
				if !nerr.Timeout() {
					t.Errorf("want timeout error; got: %q", nerr)
					continue
				}
				if !tt.wantTimeout {
					if !retry {
						// The timeout may be set too short. Retry with a longer one.
						t.Logf("unexpected timeout for path %q after %v; retrying with longer timeout", tt.path, timeout)
						timeout *= 2
						retry = true
					}
				}
				if !strings.Contains(err.Error(), "timeout awaiting response headers") {
					t.Errorf("%d. unexpected error: %v", i, err)
				}
				continue
			}
			if tt.wantTimeout {
				t.Errorf(`no error for path %q; expected "timeout awaiting response headers"`, tt.path)
				continue
			}
			if res.StatusCode != 200 {
				t.Errorf("%d for path %q status = %d; want 200", i, tt.path, res.StatusCode)
			}
		}

		srvWG.Wait()
		ts.Close()
	}
}

// TestTransportCancelRequest checks that Transport.CancelRequest issued
// mid-body makes further body reads fail with errRequestCanceled, delivers
// no extra bytes, and leaves no pending requests behind (HTTP/1 only).
func TestTransportCancelRequest(t *testing.T) {
	run(t, testTransportCancelRequest, []testMode{http1Mode})
}
func testTransportCancelRequest(t *testing.T, mode testMode) {
	if testing.Short() {
		t.Skip("skipping test in -short mode")
	}

	const msg = "Hello"
	unblockc := make(chan bool)
	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		io.WriteString(w, msg)
		w.(Flusher).Flush() // send headers and some body
		<-unblockc
	})).ts
	defer close(unblockc)

	c := ts.Client()
	tr := c.Transport.(*Transport)

	req, _ := NewRequest("GET", ts.URL, nil)
	res, err := c.Do(req)
	if err != nil {
		t.Fatal(err)
	}
	body := make([]byte, len(msg))
	n, _ := io.ReadFull(res.Body, body)
	if n != len(body) || !bytes.Equal(body, []byte(msg)) {
		t.Errorf("Body = %q; want %q", body[:n], msg)
	}
	// The handler is now blocked on unblockc, so the rest of the body can
	// only end through cancellation.
	tr.CancelRequest(req)

	tail, err := io.ReadAll(res.Body)
	res.Body.Close()
	if err != ExportErrRequestCanceled {
		t.Errorf("Body.Read error = %v; want errRequestCanceled", err)
	} else if len(tail) > 0 {
		t.Errorf("Spurious bytes from Body.Read: %q", tail)
	}

	// Verify no outstanding requests after readLoop/writeLoop
	// goroutines shut down.
	waitCondition(t, 10*time.Millisecond, func(d time.Duration) bool {
		n := tr.NumPendingRequestsForTesting()
		if n > 0 {
			if d > 0 {
				t.Logf("pending requests = %d after %v (want 0)", n, d)
			}
			return false
		}
		return true
	})
}

// testTransportCancelRequestInDo runs Client.Do for req (carrying body, which
// may be nil) in a goroutine. It hands the handler one value on unblockc;
// that unbuffered send returns only after the handler has received it. It
// then repeatedly calls Transport.CancelRequest until Do has returned.
func testTransportCancelRequestInDo(t *testing.T, mode testMode, body io.Reader) {
	if testing.Short() {
		t.Skip("skipping test in -short mode")
	}
	unblockc := make(chan bool)
	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		<-unblockc
	})).ts
	defer close(unblockc)

	c := ts.Client()
	tr := c.Transport.(*Transport)

	donec := make(chan bool)
	req, _ := NewRequest("GET", ts.URL, body)
	go func() {
		defer close(donec)
		c.Do(req)
	}()

	unblockc <- true
	waitCondition(t, 10*time.Millisecond, func(d time.Duration) bool {
		tr.CancelRequest(req)
		select {
		case <-donec:
			return true
		default:
			if d > 0 {
				t.Logf("Do of canceled request has not returned after %v", d)
			}
			return false
		}
	})
}

// TestTransportCancelRequestInDo runs testTransportCancelRequestInDo without
// a request body.
func TestTransportCancelRequestInDo(t *testing.T) {
	run(t, func(t *testing.T, mode testMode) {
		testTransportCancelRequestInDo(t, mode, nil)
	}, []testMode{http1Mode})
}

// TestTransportCancelRequestWithBodyInDo runs testTransportCancelRequestInDo
// with a one-byte request body.
func TestTransportCancelRequestWithBodyInDo(t *testing.T) {
	run(t, func(t *testing.T, mode testMode) {
		testTransportCancelRequestInDo(t, mode, bytes.NewBuffer([]byte{0}))
	}, []testMode{http1Mode})
}

// TestTransportCancelRequestInDial cancels a request while its custom Dial is
// blocked. It checks the exact order of logged events and the resulting
// error, and checks that calling CancelRequest a second time does not panic.
func TestTransportCancelRequestInDial(t *testing.T) {
	defer afterTest(t)
	if testing.Short() {
		t.Skip("skipping test in -short mode")
	}
	var logbuf strings.Builder
	eventLog := log.New(&logbuf, "", 0)

	unblockDial := make(chan bool)
	defer close(unblockDial)

	inDial := make(chan bool)
	tr := &Transport{
		Dial: func(network, addr string) (net.Conn, error) {
			eventLog.Println("dial: blocking")
			// Wait for the test goroutine's go-ahead; receiving false
			// (e.g. from a closed channel) aborts the dial.
			if !<-inDial {
				return nil, errors.New("main Test goroutine exited")
			}
			<-unblockDial
			return nil, errors.New("nope")
		},
	}
	cl := &Client{Transport: tr}
	gotres := make(chan bool)
	req, _ := NewRequest("GET", "http://something.no-network.tld/", nil)
	go func() {
		_, err := cl.Do(req)
		eventLog.Printf("Get = %v", err)
		gotres <- true
	}()

	// Returns once Dial has logged and is about to block on unblockDial.
	inDial <- true

	eventLog.Printf("canceling")
	tr.CancelRequest(req)
	tr.CancelRequest(req) // used to panic on second call

	if d, ok := t.Deadline(); ok {
		// When the test's deadline is about to expire, log the pending events for
		// better debugging.
		timeout := time.Until(d) * 19 / 20 // Allow 5% for cleanup.
		timer := time.AfterFunc(timeout, func() {
			panic(fmt.Sprintf("hang in %s. events are: %s", t.Name(), logbuf.String()))
		})
		defer timer.Stop()
	}
	<-gotres

	got := logbuf.String()
	want := `dial: blocking
canceling
Get = Get "http://something.no-network.tld/": net/http: request canceled while waiting for connection
`
	if got != want {
		t.Errorf("Got events:\n%s\nWant:\n%s", got, want)
	}
}

// TestCancelRequestWithChannel is like TestTransportCancelRequest, but cancels
// by closing the Request.Cancel channel instead of calling
// Transport.CancelRequest.
func TestCancelRequestWithChannel(t *testing.T) { run(t, testCancelRequestWithChannel) }
func testCancelRequestWithChannel(t *testing.T, mode testMode) {
	if testing.Short() {
		t.Skip("skipping test in -short mode")
	}

	const msg = "Hello"
	unblockc := make(chan struct{})
	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		io.WriteString(w, msg)
		w.(Flusher).Flush() // send headers and some body
		<-unblockc
	})).ts
	defer close(unblockc)

	c := ts.Client()
	tr := c.Transport.(*Transport)

	req, _ := NewRequest("GET", ts.URL, nil)
	cancel := make(chan struct{})
	req.Cancel = cancel

	res, err := c.Do(req)
	if err != nil {
		t.Fatal(err)
	}
	body := make([]byte, len(msg))
	n, _ := io.ReadFull(res.Body, body)
	if n != len(body) || !bytes.Equal(body, []byte(msg)) {
		t.Errorf("Body = %q; want %q", body[:n], msg)
	}
	close(cancel)

	tail, err := io.ReadAll(res.Body)
	res.Body.Close()
	if err != ExportErrRequestCanceled {
		t.Errorf("Body.Read error = %v; want errRequestCanceled", err)
	} else if len(tail) > 0 {
		t.Errorf("Spurious bytes from Body.Read: %q", tail)
	}

	// Verify no outstanding requests after readLoop/writeLoop
	// goroutines shut down.
	waitCondition(t, 10*time.Millisecond, func(d time.Duration) bool {
		n := tr.NumPendingRequestsForTesting()
		if n > 0 {
			if d > 0 {
				t.Logf("pending requests = %d after %v (want 0)", n, d)
			}
			return false
		}
		return true
	})
}

// Issue 51354
// TestCancelRequestWithBodyWithChannel repeats the Request.Cancel scenario
// with a POST that carries a request body (HTTP/1 only).
func TestCancelRequestWithBodyWithChannel(t *testing.T) {
	run(t, testCancelRequestWithBodyWithChannel, []testMode{http1Mode})
}
func testCancelRequestWithBodyWithChannel(t *testing.T, mode testMode) {
	if testing.Short() {
		t.Skip("skipping test in -short mode")
	}

	const msg = "Hello"
	unblockc := make(chan struct{})
	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		io.WriteString(w, msg)
		w.(Flusher).Flush() // send headers and some body
		<-unblockc
	})).ts
	defer close(unblockc)

	c := ts.Client()
	tr := c.Transport.(*Transport)

	req, _ := NewRequest("POST", ts.URL, strings.NewReader("withbody"))
	cancel := make(chan struct{})
	req.Cancel = cancel

	res, err := c.Do(req)
	if err != nil {
		t.Fatal(err)
	}
	body := make([]byte, len(msg))
	n, _ := io.ReadFull(res.Body, body)
	if n != len(body) || !bytes.Equal(body, []byte(msg)) {
		t.Errorf("Body = %q; want %q", body[:n], msg)
	}
	close(cancel)

	tail, err := io.ReadAll(res.Body)
	res.Body.Close()
	if err != ExportErrRequestCanceled {
		t.Errorf("Body.Read error = %v; want errRequestCanceled", err)
	} else if len(tail) > 0 {
		t.Errorf("Spurious bytes from Body.Read: %q", tail)
	}

	// Verify no outstanding requests after readLoop/writeLoop
	// goroutines shut down.
	waitCondition(t, 10*time.Millisecond, func(d time.Duration) bool {
		n := tr.NumPendingRequestsForTesting()
		if n > 0 {
			if d > 0 {
				t.Logf("pending requests = %d after %v (want 0)", n, d)
			}
			return false
		}
		return true
	})
}

// TestCancelRequestWithChannelBeforeDo_Cancel covers an already-closed
// Request.Cancel channel.
func TestCancelRequestWithChannelBeforeDo_Cancel(t *testing.T) {
	run(t, func(t *testing.T, mode testMode) {
		testCancelRequestWithChannelBeforeDo(t, mode, false)
	})
}

// TestCancelRequestWithChannelBeforeDo_Context covers an already-canceled
// request context.
func TestCancelRequestWithChannelBeforeDo_Context(t *testing.T) {
	run(t, func(t *testing.T, mode testMode) {
		testCancelRequestWithChannelBeforeDo(t, mode, true)
	})
}

// testCancelRequestWithChannelBeforeDo cancels req before calling Client.Do.
// With withCtx it uses an already-canceled context and expects
// context.Canceled; otherwise it uses an already-closed Request.Cancel
// channel and expects an error mentioning "canceled".
func testCancelRequestWithChannelBeforeDo(t *testing.T, mode testMode, withCtx bool) {
	unblockc := make(chan bool)
	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		<-unblockc
	})).ts
	defer close(unblockc)

	c := ts.Client()

	req, _ := NewRequest("GET", ts.URL, nil)
	if withCtx {
		ctx, cancel := context.WithCancel(context.Background())
		cancel()
		req = req.WithContext(ctx)
	} else {
		ch := make(chan struct{})
		req.Cancel = ch
		close(ch)
	}

	_, err := c.Do(req)
	// Compare against the underlying cause, not the *url.Error wrapper.
	if ue, ok := err.(*url.Error); ok {
		err = ue.Err
	}
	if withCtx {
		if err != context.Canceled {
			t.Errorf("Do error = %v; want %v", err, context.Canceled)
		}
	} else {
		if err == nil || !strings.Contains(err.Error(), "canceled") {
			t.Errorf("Do error = %v; want cancellation", err)
		}
	}
}

// Issue 11020. The returned error message should be errRequestCanceled
// The request runs over a net.Pipe. The server side reads only the 3-byte
// method and never replies, so CancelRequest lands before any response
// headers arrive.
func TestTransportCancelBeforeResponseHeaders(t *testing.T) {
	defer afterTest(t)

	serverConnCh := make(chan net.Conn, 1)
	tr := &Transport{
		Dial: func(network, addr string) (net.Conn, error) {
			cc, sc := net.Pipe()
			serverConnCh <- sc
			return cc, nil
		},
	}
	defer tr.CloseIdleConnections()
	errc := make(chan error, 1)
	req, _ := NewRequest("GET", "http://example.com/", nil)
	go func() {
		_, err := tr.RoundTrip(req)
		errc <- err
	}()

	sc := <-serverConnCh
	verb := make([]byte, 3)
	if _, err := io.ReadFull(sc, verb); err != nil {
		t.Errorf("Error reading HTTP verb from server: %v", err)
	}
	if string(verb) != "GET" {
		t.Errorf("server received %q; want GET", verb)
	}
	defer sc.Close()

	tr.CancelRequest(req)

	err := <-errc
	if err == nil {
		t.Fatalf("unexpected success from RoundTrip")
	}
	if err != ExportErrRequestCanceled {
		t.Errorf("RoundTrip error = %v; want ExportErrRequestCanceled", err)
	}
}

// golang.org/issue/3672 -- Client can't close HTTP stream
// Calling Close on a Response.Body used to just read until EOF.
// Now it actually closes the TCP connection.
2862 func TestTransportCloseResponseBody(t *testing.T) { run(t, testTransportCloseResponseBody) } 2863 func testTransportCloseResponseBody(t *testing.T, mode testMode) { 2864 writeErr := make(chan error, 1) 2865 msg := []byte("young\n") 2866 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { 2867 for { 2868 _, err := w.Write(msg) 2869 if err != nil { 2870 writeErr <- err 2871 return 2872 } 2873 w.(Flusher).Flush() 2874 } 2875 })).ts 2876 2877 c := ts.Client() 2878 tr := c.Transport.(*Transport) 2879 2880 req, _ := NewRequest("GET", ts.URL, nil) 2881 defer tr.CancelRequest(req) 2882 2883 res, err := c.Do(req) 2884 if err != nil { 2885 t.Fatal(err) 2886 } 2887 2888 const repeats = 3 2889 buf := make([]byte, len(msg)*repeats) 2890 want := bytes.Repeat(msg, repeats) 2891 2892 _, err = io.ReadFull(res.Body, buf) 2893 if err != nil { 2894 t.Fatal(err) 2895 } 2896 if !bytes.Equal(buf, want) { 2897 t.Fatalf("read %q; want %q", buf, want) 2898 } 2899 2900 if err := res.Body.Close(); err != nil { 2901 t.Errorf("Close = %v", err) 2902 } 2903 2904 if err := <-writeErr; err == nil { 2905 t.Errorf("expected non-nil write error") 2906 } 2907 } 2908 2909 type fooProto struct{} 2910 2911 func (fooProto) RoundTrip(req *Request) (*Response, error) { 2912 res := &Response{ 2913 Status: "200 OK", 2914 StatusCode: 200, 2915 Header: make(Header), 2916 Body: io.NopCloser(strings.NewReader("You wanted " + req.URL.String())), 2917 } 2918 return res, nil 2919 } 2920 2921 func TestTransportAltProto(t *testing.T) { 2922 defer afterTest(t) 2923 tr := &Transport{} 2924 c := &Client{Transport: tr} 2925 tr.RegisterProtocol("foo", fooProto{}) 2926 res, err := c.Get("foo://bar.com/path") 2927 if err != nil { 2928 t.Fatal(err) 2929 } 2930 bodyb, err := io.ReadAll(res.Body) 2931 if err != nil { 2932 t.Fatal(err) 2933 } 2934 body := string(bodyb) 2935 if e := "You wanted foo://bar.com/path"; body != e { 2936 t.Errorf("got response %q, want %q", body, e) 2937 } 2938 } 
2939 2940 func TestTransportNoHost(t *testing.T) { 2941 defer afterTest(t) 2942 tr := &Transport{} 2943 _, err := tr.RoundTrip(&Request{ 2944 Header: make(Header), 2945 URL: &url.URL{ 2946 Scheme: "http", 2947 }, 2948 }) 2949 want := "http: no Host in request URL" 2950 if got := fmt.Sprint(err); got != want { 2951 t.Errorf("error = %v; want %q", err, want) 2952 } 2953 } 2954 2955 // Issue 13311 2956 func TestTransportEmptyMethod(t *testing.T) { 2957 req, _ := NewRequest("GET", "http://foo.com/", nil) 2958 req.Method = "" // docs say "For client requests an empty string means GET" 2959 got, err := httputil.DumpRequestOut(req, false) // DumpRequestOut uses Transport 2960 if err != nil { 2961 t.Fatal(err) 2962 } 2963 if !strings.Contains(string(got), "GET ") { 2964 t.Fatalf("expected substring 'GET '; got: %s", got) 2965 } 2966 } 2967 2968 func TestTransportSocketLateBinding(t *testing.T) { run(t, testTransportSocketLateBinding) } 2969 func testTransportSocketLateBinding(t *testing.T, mode testMode) { 2970 mux := NewServeMux() 2971 fooGate := make(chan bool, 1) 2972 mux.HandleFunc("/foo", func(w ResponseWriter, r *Request) { 2973 w.Header().Set("foo-ipport", r.RemoteAddr) 2974 w.(Flusher).Flush() 2975 <-fooGate 2976 }) 2977 mux.HandleFunc("/bar", func(w ResponseWriter, r *Request) { 2978 w.Header().Set("bar-ipport", r.RemoteAddr) 2979 }) 2980 ts := newClientServerTest(t, mode, mux).ts 2981 2982 dialGate := make(chan bool, 1) 2983 dialing := make(chan bool) 2984 c := ts.Client() 2985 c.Transport.(*Transport).Dial = func(n, addr string) (net.Conn, error) { 2986 for { 2987 select { 2988 case ok := <-dialGate: 2989 if !ok { 2990 return nil, errors.New("manually closed") 2991 } 2992 return net.Dial(n, addr) 2993 case dialing <- true: 2994 } 2995 } 2996 } 2997 defer close(dialGate) 2998 2999 dialGate <- true // only allow one dial 3000 fooRes, err := c.Get(ts.URL + "/foo") 3001 if err != nil { 3002 t.Fatal(err) 3003 } 3004 fooAddr := fooRes.Header.Get("foo-ipport") 3005 if 
fooAddr == "" { 3006 t.Fatal("No addr on /foo request") 3007 } 3008 3009 fooDone := make(chan struct{}) 3010 go func() { 3011 // We know that the foo Dial completed and reached the handler because we 3012 // read its header. Wait for the bar request to block in Dial, then 3013 // let the foo response finish so we can use its connection for /bar. 3014 3015 if mode == http2Mode { 3016 // In HTTP/2 mode, the second Dial won't happen because the protocol 3017 // multiplexes the streams by default. Just sleep for an arbitrary time; 3018 // the test should pass regardless of how far the bar request gets by this 3019 // point. 3020 select { 3021 case <-dialing: 3022 t.Errorf("unexpected second Dial in HTTP/2 mode") 3023 case <-time.After(10 * time.Millisecond): 3024 } 3025 } else { 3026 <-dialing 3027 } 3028 fooGate <- true 3029 io.Copy(io.Discard, fooRes.Body) 3030 fooRes.Body.Close() 3031 close(fooDone) 3032 }() 3033 defer func() { 3034 <-fooDone 3035 }() 3036 3037 barRes, err := c.Get(ts.URL + "/bar") 3038 if err != nil { 3039 t.Fatal(err) 3040 } 3041 barAddr := barRes.Header.Get("bar-ipport") 3042 if barAddr != fooAddr { 3043 t.Fatalf("/foo came from conn %q; /bar came from %q instead", fooAddr, barAddr) 3044 } 3045 barRes.Body.Close() 3046 } 3047 3048 // Issue 2184 3049 func TestTransportReading100Continue(t *testing.T) { 3050 defer afterTest(t) 3051 3052 const numReqs = 5 3053 reqBody := func(n int) string { return fmt.Sprintf("request body %d", n) } 3054 reqID := func(n int) string { return fmt.Sprintf("REQ-ID-%d", n) } 3055 3056 send100Response := func(w *io.PipeWriter, r *io.PipeReader) { 3057 defer w.Close() 3058 defer r.Close() 3059 br := bufio.NewReader(r) 3060 n := 0 3061 for { 3062 n++ 3063 req, err := ReadRequest(br) 3064 if err == io.EOF { 3065 return 3066 } 3067 if err != nil { 3068 t.Error(err) 3069 return 3070 } 3071 slurp, err := io.ReadAll(req.Body) 3072 if err != nil { 3073 t.Errorf("Server request body slurp: %v", err) 3074 return 3075 } 3076 id := 
req.Header.Get("Request-Id") 3077 resCode := req.Header.Get("X-Want-Response-Code") 3078 if resCode == "" { 3079 resCode = "100 Continue" 3080 if string(slurp) != reqBody(n) { 3081 t.Errorf("Server got %q, %v; want %q", slurp, err, reqBody(n)) 3082 } 3083 } 3084 body := fmt.Sprintf("Response number %d", n) 3085 v := []byte(strings.Replace(fmt.Sprintf(`HTTP/1.1 %s 3086 Date: Thu, 28 Feb 2013 17:55:41 GMT 3087 3088 HTTP/1.1 200 OK 3089 Content-Type: text/html 3090 Echo-Request-Id: %s 3091 Content-Length: %d 3092 3093 %s`, resCode, id, len(body), body), "\n", "\r\n", -1)) 3094 w.Write(v) 3095 if id == reqID(numReqs) { 3096 return 3097 } 3098 } 3099 3100 } 3101 3102 tr := &Transport{ 3103 Dial: func(n, addr string) (net.Conn, error) { 3104 sr, sw := io.Pipe() // server read/write 3105 cr, cw := io.Pipe() // client read/write 3106 conn := &rwTestConn{ 3107 Reader: cr, 3108 Writer: sw, 3109 closeFunc: func() error { 3110 sw.Close() 3111 cw.Close() 3112 return nil 3113 }, 3114 } 3115 go send100Response(cw, sr) 3116 return conn, nil 3117 }, 3118 DisableKeepAlives: false, 3119 } 3120 defer tr.CloseIdleConnections() 3121 c := &Client{Transport: tr} 3122 3123 testResponse := func(req *Request, name string, wantCode int) { 3124 t.Helper() 3125 res, err := c.Do(req) 3126 if err != nil { 3127 t.Fatalf("%s: Do: %v", name, err) 3128 } 3129 if res.StatusCode != wantCode { 3130 t.Fatalf("%s: Response Statuscode=%d; want %d", name, res.StatusCode, wantCode) 3131 } 3132 if id, idBack := req.Header.Get("Request-Id"), res.Header.Get("Echo-Request-Id"); id != "" && id != idBack { 3133 t.Errorf("%s: response id %q != request id %q", name, idBack, id) 3134 } 3135 _, err = io.ReadAll(res.Body) 3136 if err != nil { 3137 t.Fatalf("%s: Slurp error: %v", name, err) 3138 } 3139 } 3140 3141 // Few 100 responses, making sure we're not off-by-one. 
3142 for i := 1; i <= numReqs; i++ { 3143 req, _ := NewRequest("POST", "http://dummy.tld/", strings.NewReader(reqBody(i))) 3144 req.Header.Set("Request-Id", reqID(i)) 3145 testResponse(req, fmt.Sprintf("100, %d/%d", i, numReqs), 200) 3146 } 3147 } 3148 3149 // Issue 17739: the HTTP client must ignore any unknown 1xx 3150 // informational responses before the actual response. 3151 func TestTransportIgnore1xxResponses(t *testing.T) { 3152 run(t, testTransportIgnore1xxResponses, []testMode{http1Mode}) 3153 } 3154 func testTransportIgnore1xxResponses(t *testing.T, mode testMode) { 3155 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { 3156 conn, buf, _ := w.(Hijacker).Hijack() 3157 buf.Write([]byte("HTTP/1.1 123 OneTwoThree\r\nFoo: bar\r\n\r\nHTTP/1.1 200 OK\r\nBar: baz\r\nContent-Length: 5\r\n\r\nHello")) 3158 buf.Flush() 3159 conn.Close() 3160 })) 3161 cst.tr.DisableKeepAlives = true // prevent log spam; our test server is hanging up anyway 3162 3163 var got strings.Builder 3164 3165 req, _ := NewRequest("GET", cst.ts.URL, nil) 3166 req = req.WithContext(httptrace.WithClientTrace(context.Background(), &httptrace.ClientTrace{ 3167 Got1xxResponse: func(code int, header textproto.MIMEHeader) error { 3168 fmt.Fprintf(&got, "1xx: code=%v, header=%v\n", code, header) 3169 return nil 3170 }, 3171 })) 3172 res, err := cst.c.Do(req) 3173 if err != nil { 3174 t.Fatal(err) 3175 } 3176 defer res.Body.Close() 3177 3178 res.Write(&got) 3179 want := "1xx: code=123, header=map[Foo:[bar]]\nHTTP/1.1 200 OK\r\nContent-Length: 5\r\nBar: baz\r\n\r\nHello" 3180 if got.String() != want { 3181 t.Errorf(" got: %q\nwant: %q\n", got.String(), want) 3182 } 3183 } 3184 3185 func TestTransportLimits1xxResponses(t *testing.T) { 3186 run(t, testTransportLimits1xxResponses, []testMode{http1Mode}) 3187 } 3188 func testTransportLimits1xxResponses(t *testing.T, mode testMode) { 3189 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { 
3190 conn, buf, _ := w.(Hijacker).Hijack() 3191 for i := 0; i < 10; i++ { 3192 buf.Write([]byte("HTTP/1.1 123 OneTwoThree\r\n\r\n")) 3193 } 3194 buf.Write([]byte("HTTP/1.1 204 No Content\r\n\r\n")) 3195 buf.Flush() 3196 conn.Close() 3197 })) 3198 cst.tr.DisableKeepAlives = true // prevent log spam; our test server is hanging up anyway 3199 3200 res, err := cst.c.Get(cst.ts.URL) 3201 if res != nil { 3202 defer res.Body.Close() 3203 } 3204 got := fmt.Sprint(err) 3205 wantSub := "too many 1xx informational responses" 3206 if !strings.Contains(got, wantSub) { 3207 t.Errorf("Get error = %v; want substring %q", err, wantSub) 3208 } 3209 } 3210 3211 // Issue 26161: the HTTP client must treat 101 responses 3212 // as the final response. 3213 func TestTransportTreat101Terminal(t *testing.T) { 3214 run(t, testTransportTreat101Terminal, []testMode{http1Mode}) 3215 } 3216 func testTransportTreat101Terminal(t *testing.T, mode testMode) { 3217 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { 3218 conn, buf, _ := w.(Hijacker).Hijack() 3219 buf.Write([]byte("HTTP/1.1 101 Switching Protocols\r\n\r\n")) 3220 buf.Write([]byte("HTTP/1.1 204 No Content\r\n\r\n")) 3221 buf.Flush() 3222 conn.Close() 3223 })) 3224 res, err := cst.c.Get(cst.ts.URL) 3225 if err != nil { 3226 t.Fatal(err) 3227 } 3228 defer res.Body.Close() 3229 if res.StatusCode != StatusSwitchingProtocols { 3230 t.Errorf("StatusCode = %v; want 101 Switching Protocols", res.StatusCode) 3231 } 3232 } 3233 3234 type proxyFromEnvTest struct { 3235 req string // URL to fetch; blank means "http://example.com" 3236 3237 env string // HTTP_PROXY 3238 httpsenv string // HTTPS_PROXY 3239 noenv string // NO_PROXY 3240 reqmeth string // REQUEST_METHOD 3241 3242 want string 3243 wanterr error 3244 } 3245 3246 func (t proxyFromEnvTest) String() string { 3247 var buf strings.Builder 3248 space := func() { 3249 if buf.Len() > 0 { 3250 buf.WriteByte(' ') 3251 } 3252 } 3253 if t.env != "" { 3254 
fmt.Fprintf(&buf, "http_proxy=%q", t.env) 3255 } 3256 if t.httpsenv != "" { 3257 space() 3258 fmt.Fprintf(&buf, "https_proxy=%q", t.httpsenv) 3259 } 3260 if t.noenv != "" { 3261 space() 3262 fmt.Fprintf(&buf, "no_proxy=%q", t.noenv) 3263 } 3264 if t.reqmeth != "" { 3265 space() 3266 fmt.Fprintf(&buf, "request_method=%q", t.reqmeth) 3267 } 3268 req := "http://example.com" 3269 if t.req != "" { 3270 req = t.req 3271 } 3272 space() 3273 fmt.Fprintf(&buf, "req=%q", req) 3274 return strings.TrimSpace(buf.String()) 3275 } 3276 3277 var proxyFromEnvTests = []proxyFromEnvTest{ 3278 {env: "127.0.0.1:8080", want: "http://127.0.0.1:8080"}, 3279 {env: "cache.corp.example.com:1234", want: "http://cache.corp.example.com:1234"}, 3280 {env: "cache.corp.example.com", want: "http://cache.corp.example.com"}, 3281 {env: "https://cache.corp.example.com", want: "https://cache.corp.example.com"}, 3282 {env: "http://127.0.0.1:8080", want: "http://127.0.0.1:8080"}, 3283 {env: "https://127.0.0.1:8080", want: "https://127.0.0.1:8080"}, 3284 {env: "socks5://127.0.0.1", want: "socks5://127.0.0.1"}, 3285 {env: "socks5h://127.0.0.1", want: "socks5h://127.0.0.1"}, 3286 3287 // Don't use secure for http 3288 {req: "http://insecure.tld/", env: "http.proxy.tld", httpsenv: "secure.proxy.tld", want: "http://http.proxy.tld"}, 3289 // Use secure for https. 3290 {req: "https://secure.tld/", env: "http.proxy.tld", httpsenv: "secure.proxy.tld", want: "http://secure.proxy.tld"}, 3291 {req: "https://secure.tld/", env: "http.proxy.tld", httpsenv: "https://secure.proxy.tld", want: "https://secure.proxy.tld"}, 3292 3293 // Issue 16405: don't use HTTP_PROXY in a CGI environment, 3294 // where HTTP_PROXY can be attacker-controlled. 
3295 {env: "http://10.1.2.3:8080", reqmeth: "POST", 3296 want: "<nil>", 3297 wanterr: errors.New("refusing to use HTTP_PROXY value in CGI environment; see golang.org/s/cgihttpproxy")}, 3298 3299 {want: "<nil>"}, 3300 3301 {noenv: "example.com", req: "http://example.com/", env: "proxy", want: "<nil>"}, 3302 {noenv: ".example.com", req: "http://example.com/", env: "proxy", want: "http://proxy"}, 3303 {noenv: "ample.com", req: "http://example.com/", env: "proxy", want: "http://proxy"}, 3304 {noenv: "example.com", req: "http://foo.example.com/", env: "proxy", want: "<nil>"}, 3305 {noenv: ".foo.com", req: "http://example.com/", env: "proxy", want: "http://proxy"}, 3306 } 3307 3308 func testProxyForRequest(t *testing.T, tt proxyFromEnvTest, proxyForRequest func(req *Request) (*url.URL, error)) { 3309 t.Helper() 3310 reqURL := tt.req 3311 if reqURL == "" { 3312 reqURL = "http://example.com" 3313 } 3314 req, _ := NewRequest("GET", reqURL, nil) 3315 url, err := proxyForRequest(req) 3316 if g, e := fmt.Sprintf("%v", err), fmt.Sprintf("%v", tt.wanterr); g != e { 3317 t.Errorf("%v: got error = %q, want %q", tt, g, e) 3318 return 3319 } 3320 if got := fmt.Sprintf("%s", url); got != tt.want { 3321 t.Errorf("%v: got URL = %q, want %q", tt, url, tt.want) 3322 } 3323 } 3324 3325 func TestProxyFromEnvironment(t *testing.T) { 3326 ResetProxyEnv() 3327 defer ResetProxyEnv() 3328 for _, tt := range proxyFromEnvTests { 3329 testProxyForRequest(t, tt, func(req *Request) (*url.URL, error) { 3330 os.Setenv("HTTP_PROXY", tt.env) 3331 os.Setenv("HTTPS_PROXY", tt.httpsenv) 3332 os.Setenv("NO_PROXY", tt.noenv) 3333 os.Setenv("REQUEST_METHOD", tt.reqmeth) 3334 ResetCachedEnvironment() 3335 return ProxyFromEnvironment(req) 3336 }) 3337 } 3338 } 3339 3340 func TestProxyFromEnvironmentLowerCase(t *testing.T) { 3341 ResetProxyEnv() 3342 defer ResetProxyEnv() 3343 for _, tt := range proxyFromEnvTests { 3344 testProxyForRequest(t, tt, func(req *Request) (*url.URL, error) { 3345 
os.Setenv("http_proxy", tt.env) 3346 os.Setenv("https_proxy", tt.httpsenv) 3347 os.Setenv("no_proxy", tt.noenv) 3348 os.Setenv("REQUEST_METHOD", tt.reqmeth) 3349 ResetCachedEnvironment() 3350 return ProxyFromEnvironment(req) 3351 }) 3352 } 3353 } 3354 3355 func TestIdleConnChannelLeak(t *testing.T) { 3356 run(t, testIdleConnChannelLeak, []testMode{http1Mode}, testNotParallel) 3357 } 3358 func testIdleConnChannelLeak(t *testing.T, mode testMode) { 3359 // Not parallel: uses global test hooks. 3360 var mu sync.Mutex 3361 var n int 3362 3363 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { 3364 mu.Lock() 3365 n++ 3366 mu.Unlock() 3367 })).ts 3368 3369 const nReqs = 5 3370 didRead := make(chan bool, nReqs) 3371 SetReadLoopBeforeNextReadHook(func() { didRead <- true }) 3372 defer SetReadLoopBeforeNextReadHook(nil) 3373 3374 c := ts.Client() 3375 tr := c.Transport.(*Transport) 3376 tr.Dial = func(netw, addr string) (net.Conn, error) { 3377 return net.Dial(netw, ts.Listener.Addr().String()) 3378 } 3379 3380 // First, without keep-alives. 3381 for _, disableKeep := range []bool{true, false} { 3382 tr.DisableKeepAlives = disableKeep 3383 for i := 0; i < nReqs; i++ { 3384 _, err := c.Get(fmt.Sprintf("http://foo-host-%d.tld/", i)) 3385 if err != nil { 3386 t.Fatal(err) 3387 } 3388 // Note: no res.Body.Close is needed here, since the 3389 // response Content-Length is zero. Perhaps the test 3390 // should be more explicit and use a HEAD, but tests 3391 // elsewhere guarantee that zero byte responses generate 3392 // a "Content-Length: 0" instead of chunking. 3393 } 3394 3395 // At this point, each of the 5 Transport.readLoop goroutines 3396 // are scheduling noting that there are no response bodies (see 3397 // earlier comment), and are then calling putIdleConn, which 3398 // decrements this count. Usually that happens quickly, which is 3399 // why this test has seemed to work for ages. 
But it's still 3400 // racey: we have wait for them to finish first. See Issue 10427 3401 for i := 0; i < nReqs; i++ { 3402 <-didRead 3403 } 3404 3405 if got := tr.IdleConnWaitMapSizeForTesting(); got != 0 { 3406 t.Fatalf("for DisableKeepAlives = %v, map size = %d; want 0", disableKeep, got) 3407 } 3408 } 3409 } 3410 3411 // Verify the status quo: that the Client.Post function coerces its 3412 // body into a ReadCloser if it's a Closer, and that the Transport 3413 // then closes it. 3414 func TestTransportClosesRequestBody(t *testing.T) { 3415 run(t, testTransportClosesRequestBody, []testMode{http1Mode}) 3416 } 3417 func testTransportClosesRequestBody(t *testing.T, mode testMode) { 3418 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { 3419 io.Copy(io.Discard, r.Body) 3420 })).ts 3421 3422 c := ts.Client() 3423 3424 closes := 0 3425 3426 res, err := c.Post(ts.URL, "text/plain", countCloseReader{&closes, strings.NewReader("hello")}) 3427 if err != nil { 3428 t.Fatal(err) 3429 } 3430 res.Body.Close() 3431 if closes != 1 { 3432 t.Errorf("closes = %d; want 1", closes) 3433 } 3434 } 3435 3436 func TestTransportTLSHandshakeTimeout(t *testing.T) { 3437 defer afterTest(t) 3438 if testing.Short() { 3439 t.Skip("skipping in short mode") 3440 } 3441 ln := newLocalListener(t) 3442 defer ln.Close() 3443 testdonec := make(chan struct{}) 3444 defer close(testdonec) 3445 3446 go func() { 3447 c, err := ln.Accept() 3448 if err != nil { 3449 t.Error(err) 3450 return 3451 } 3452 <-testdonec 3453 c.Close() 3454 }() 3455 3456 tr := &Transport{ 3457 Dial: func(_, _ string) (net.Conn, error) { 3458 return net.Dial("tcp", ln.Addr().String()) 3459 }, 3460 TLSHandshakeTimeout: 250 * time.Millisecond, 3461 } 3462 cl := &Client{Transport: tr} 3463 _, err := cl.Get("https://dummy.tld/") 3464 if err == nil { 3465 t.Error("expected error") 3466 return 3467 } 3468 ue, ok := err.(*url.Error) 3469 if !ok { 3470 t.Errorf("expected url.Error; got %#v", err) 3471 
return 3472 } 3473 ne, ok := ue.Err.(net.Error) 3474 if !ok { 3475 t.Errorf("expected net.Error; got %#v", err) 3476 return 3477 } 3478 if !ne.Timeout() { 3479 t.Errorf("expected timeout error; got %v", err) 3480 } 3481 if !strings.Contains(err.Error(), "handshake timeout") { 3482 t.Errorf("expected 'handshake timeout' in error; got %v", err) 3483 } 3484 } 3485 3486 // Trying to repro golang.org/issue/3514 3487 func TestTLSServerClosesConnection(t *testing.T) { 3488 run(t, testTLSServerClosesConnection, []testMode{https1Mode}) 3489 } 3490 func testTLSServerClosesConnection(t *testing.T, mode testMode) { 3491 closedc := make(chan bool, 1) 3492 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { 3493 if strings.Contains(r.URL.Path, "/keep-alive-then-die") { 3494 conn, _, _ := w.(Hijacker).Hijack() 3495 conn.Write([]byte("HTTP/1.1 200 OK\r\nContent-Length: 3\r\n\r\nfoo")) 3496 conn.Close() 3497 closedc <- true 3498 return 3499 } 3500 fmt.Fprintf(w, "hello") 3501 })).ts 3502 3503 c := ts.Client() 3504 tr := c.Transport.(*Transport) 3505 3506 var nSuccess = 0 3507 var errs []error 3508 const trials = 20 3509 for i := 0; i < trials; i++ { 3510 tr.CloseIdleConnections() 3511 res, err := c.Get(ts.URL + "/keep-alive-then-die") 3512 if err != nil { 3513 t.Fatal(err) 3514 } 3515 <-closedc 3516 slurp, err := io.ReadAll(res.Body) 3517 if err != nil { 3518 t.Fatal(err) 3519 } 3520 if string(slurp) != "foo" { 3521 t.Errorf("Got %q, want foo", slurp) 3522 } 3523 3524 // Now try again and see if we successfully 3525 // pick a new connection. 
3526 res, err = c.Get(ts.URL + "/") 3527 if err != nil { 3528 errs = append(errs, err) 3529 continue 3530 } 3531 slurp, err = io.ReadAll(res.Body) 3532 if err != nil { 3533 errs = append(errs, err) 3534 continue 3535 } 3536 nSuccess++ 3537 } 3538 if nSuccess > 0 { 3539 t.Logf("successes = %d of %d", nSuccess, trials) 3540 } else { 3541 t.Errorf("All runs failed:") 3542 } 3543 for _, err := range errs { 3544 t.Logf(" err: %v", err) 3545 } 3546 } 3547 3548 // byteFromChanReader is an io.Reader that reads a single byte at a 3549 // time from the channel. When the channel is closed, the reader 3550 // returns io.EOF. 3551 type byteFromChanReader chan byte 3552 3553 func (c byteFromChanReader) Read(p []byte) (n int, err error) { 3554 if len(p) == 0 { 3555 return 3556 } 3557 b, ok := <-c 3558 if !ok { 3559 return 0, io.EOF 3560 } 3561 p[0] = b 3562 return 1, nil 3563 } 3564 3565 // Verifies that the Transport doesn't reuse a connection in the case 3566 // where the server replies before the request has been fully 3567 // written. We still honor that reply (see TestIssue3595), but don't 3568 // send future requests on the connection because it's then in a 3569 // questionable state. 
// golang.org/issue/7569
func TestTransportNoReuseAfterEarlyResponse(t *testing.T) {
	run(t, testTransportNoReuseAfterEarlyResponse, []testMode{http1Mode}, testNotParallel)
}

// testTransportNoReuseAfterEarlyResponse POSTs a large body whose final byte
// is withheld (via finalBit) while the handler hijacks the connection and
// answers early with a keep-alive 200 "foo". A following GET must then get
// "bar". Only a normal handler invocation writes "bar"; the hijacked
// connection just discards whatever arrives after "foo". So the GET cannot
// have been sent on the hijacked connection.
func testTransportNoReuseAfterEarlyResponse(t *testing.T, mode testMode) {
	// Lower the exported MaxWriteWaitBeforeConnReuse knob for this test and
	// restore the previous value on return. NOTE(review): presumably this
	// bounds how long the Transport waits for an unfinished request write
	// before deciding not to reuse a conn — confirm in transport.go.
	defer func(d time.Duration) {
		*MaxWriteWaitBeforeConnReuse = d
	}(*MaxWriteWaitBeforeConnReuse)
	*MaxWriteWaitBeforeConnReuse = 10 * time.Millisecond
	// sconn holds the hijacked server-side conn so cleanup can close it.
	var sconn struct {
		sync.Mutex
		c net.Conn
	}
	// getOkay is set after the GET succeeds; it only silences the
	// "Closed server connection" log line during cleanup.
	var getOkay bool
	// copying tracks the goroutine draining the hijacked conn.
	var copying sync.WaitGroup
	// closeConn closes and forgets the hijacked server conn, if any.
	closeConn := func() {
		sconn.Lock()
		defer sconn.Unlock()
		if sconn.c != nil {
			sconn.c.Close()
			sconn.c = nil
			if !getOkay {
				t.Logf("Closed server connection")
			}
		}
	}
	// Closing the conn ends the io.Copy drain, so Wait below can return.
	defer func() {
		closeConn()
		copying.Wait()
	}()

	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		if r.Method == "GET" {
			io.WriteString(w, "bar")
			return
		}
		// POST: reply immediately, before reading the request body.
		conn, _, _ := w.(Hijacker).Hijack()
		sconn.Lock()
		sconn.c = conn
		sconn.Unlock()
		conn.Write([]byte("HTTP/1.1 200 OK\r\nContent-Length: 3\r\n\r\nfoo")) // keep-alive

		// Keep consuming the client's body so its writes are not blocked
		// by a full socket buffer.
		copying.Add(1)
		go func() {
			io.Copy(io.Discard, conn)
			copying.Done()
		}()
	})).ts
	c := ts.Client()

	// The body is bodySize-1 'x' bytes followed by one byte from finalBit.
	// byteFromChanReader.Read blocks on the channel, so the request write
	// cannot complete until the test sends on finalBit at the very end.
	const bodySize = 256 << 10
	finalBit := make(byteFromChanReader, 1)
	req, _ := NewRequest("POST", ts.URL, io.MultiReader(io.LimitReader(neverEnding('x'), bodySize-1), finalBit))
	req.ContentLength = bodySize
	res, err := c.Do(req)
	// wantBody is a helper defined elsewhere in this package; presumably it
	// checks err and that res's body equals the given string.
	if err := wantBody(res, err, "foo"); err != nil {
		t.Errorf("POST response: %v", err)
	}

	res, err = c.Get(ts.URL)
	if err := wantBody(res, err, "bar"); err != nil {
		t.Errorf("GET response: %v", err)
		return
	}
	getOkay = true  // suppress test noise
	finalBit <- 'x' // unblock the writeloop of the first Post
	close(finalBit)
}

// Tests that we don't leak Transport persistConn.readLoop goroutines
// when a server hangs up immediately after saying it would keep-alive.
func TestTransportIssue10457(t *testing.T) { run(t, testTransportIssue10457, []testMode{http1Mode}) }

// testTransportIssue10457 fetches a zero-length keep-alive response from a
// handler that closes the hijacked connection right after writing it, and
// checks that the response header still arrives.
func testTransportIssue10457(t *testing.T, mode testMode) {
	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		// Send a response with no body, keep-alive
		// (implicit), and then lie and immediately close the
		// connection. This forces the Transport's readLoop to
		// immediately Peek an io.EOF and get to the point
		// that used to hang.
		conn, _, _ := w.(Hijacker).Hijack()
		conn.Write([]byte("HTTP/1.1 200 OK\r\nFoo: Bar\r\nContent-Length: 0\r\n\r\n")) // keep-alive
		conn.Close()
	})).ts
	c := ts.Client()

	res, err := c.Get(ts.URL)
	if err != nil {
		t.Fatalf("Get: %v", err)
	}
	defer res.Body.Close()

	// Just a sanity check that we at least get the response. The real test
	// is that no readLoop goroutine is leaked. This function does not check
	// that itself. NOTE(review): there is no explicit "defer afterTest" here
	// any more; the leak check presumably happens in the run helper — confirm.
	if got, want := res.Header.Get("Foo"), "Bar"; got != want {
		t.Errorf("Foo header = %q; want %q", got, want)
	}
}

// closerFunc adapts an ordinary func() error to the io.Closer interface.
type closerFunc func() error

func (f closerFunc) Close() error { return f() }

// writerFuncConn is a net.Conn whose Write calls go to the write func;
// every other method is delegated to the embedded Conn.
type writerFuncConn struct {
	net.Conn
	write func(p []byte) (n int, err error)
}

func (c writerFuncConn) Write(p []byte) (n int, err error) { return c.write(p) }

// Issues 4677, 18241, and 17844. If we try to reuse a connection that the
// server is in the process of closing, we may end up successfully writing out
// our request (or a portion of our request) only to find a connection error
// when we try to read from (or finish writing to) the socket.
3684 // 3685 // NOTE: we resend a request only if: 3686 // - we reused a keep-alive connection 3687 // - we haven't yet received any header data 3688 // - either we wrote no bytes to the server, or the request is idempotent 3689 // 3690 // This automatically prevents an infinite resend loop because we'll run out of 3691 // the cached keep-alive connections eventually. 3692 func TestRetryRequestsOnError(t *testing.T) { 3693 run(t, testRetryRequestsOnError, testNotParallel, []testMode{http1Mode}) 3694 } 3695 func testRetryRequestsOnError(t *testing.T, mode testMode) { 3696 newRequest := func(method, urlStr string, body io.Reader) *Request { 3697 req, err := NewRequest(method, urlStr, body) 3698 if err != nil { 3699 t.Fatal(err) 3700 } 3701 return req 3702 } 3703 3704 testCases := []struct { 3705 name string 3706 failureN int 3707 failureErr error 3708 // Note that we can't just re-use the Request object across calls to c.Do 3709 // because we need to rewind Body between calls. (GetBody is only used to 3710 // rewind Body on failure and redirects, not just because it's done.) 3711 req func() *Request 3712 reqString string 3713 }{ 3714 { 3715 name: "IdempotentNoBodySomeWritten", 3716 // Believe that we've written some bytes to the server, so we know we're 3717 // not just in the "retry when no bytes sent" case". 3718 failureN: 1, 3719 // Use the specific error that shouldRetryRequest looks for with idempotent requests. 3720 failureErr: ExportErrServerClosedIdle, 3721 req: func() *Request { 3722 return newRequest("GET", "http://fake.golang", nil) 3723 }, 3724 reqString: `GET / HTTP/1.1\r\nHost: fake.golang\r\nUser-Agent: Go-http-client/1.1\r\nAccept-Encoding: gzip\r\n\r\n`, 3725 }, 3726 { 3727 name: "IdempotentGetBodySomeWritten", 3728 // Believe that we've written some bytes to the server, so we know we're 3729 // not just in the "retry when no bytes sent" case". 
3730 failureN: 1, 3731 // Use the specific error that shouldRetryRequest looks for with idempotent requests. 3732 failureErr: ExportErrServerClosedIdle, 3733 req: func() *Request { 3734 return newRequest("GET", "http://fake.golang", strings.NewReader("foo\n")) 3735 }, 3736 reqString: `GET / HTTP/1.1\r\nHost: fake.golang\r\nUser-Agent: Go-http-client/1.1\r\nContent-Length: 4\r\nAccept-Encoding: gzip\r\n\r\nfoo\n`, 3737 }, 3738 { 3739 name: "NothingWrittenNoBody", 3740 // It's key that we return 0 here -- that's what enables Transport to know 3741 // that nothing was written, even though this is a non-idempotent request. 3742 failureN: 0, 3743 failureErr: errors.New("second write fails"), 3744 req: func() *Request { 3745 return newRequest("DELETE", "http://fake.golang", nil) 3746 }, 3747 reqString: `DELETE / HTTP/1.1\r\nHost: fake.golang\r\nUser-Agent: Go-http-client/1.1\r\nAccept-Encoding: gzip\r\n\r\n`, 3748 }, 3749 { 3750 name: "NothingWrittenGetBody", 3751 // It's key that we return 0 here -- that's what enables Transport to know 3752 // that nothing was written, even though this is a non-idempotent request. 3753 failureN: 0, 3754 failureErr: errors.New("second write fails"), 3755 // Note that NewRequest will set up GetBody for strings.Reader, which is 3756 // required for the retry to occur 3757 req: func() *Request { 3758 return newRequest("POST", "http://fake.golang", strings.NewReader("foo\n")) 3759 }, 3760 reqString: `POST / HTTP/1.1\r\nHost: fake.golang\r\nUser-Agent: Go-http-client/1.1\r\nContent-Length: 4\r\nAccept-Encoding: gzip\r\n\r\nfoo\n`, 3761 }, 3762 } 3763 3764 for _, tc := range testCases { 3765 t.Run(tc.name, func(t *testing.T) { 3766 var ( 3767 mu sync.Mutex 3768 logbuf strings.Builder 3769 ) 3770 logf := func(format string, args ...any) { 3771 mu.Lock() 3772 defer mu.Unlock() 3773 fmt.Fprintf(&logbuf, format, args...) 
3774 logbuf.WriteByte('\n') 3775 } 3776 3777 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { 3778 logf("Handler") 3779 w.Header().Set("X-Status", "ok") 3780 })).ts 3781 3782 var writeNumAtomic int32 3783 c := ts.Client() 3784 c.Transport.(*Transport).Dial = func(network, addr string) (net.Conn, error) { 3785 logf("Dial") 3786 c, err := net.Dial(network, ts.Listener.Addr().String()) 3787 if err != nil { 3788 logf("Dial error: %v", err) 3789 return nil, err 3790 } 3791 return &writerFuncConn{ 3792 Conn: c, 3793 write: func(p []byte) (n int, err error) { 3794 if atomic.AddInt32(&writeNumAtomic, 1) == 2 { 3795 logf("intentional write failure") 3796 return tc.failureN, tc.failureErr 3797 } 3798 logf("Write(%q)", p) 3799 return c.Write(p) 3800 }, 3801 }, nil 3802 } 3803 3804 SetRoundTripRetried(func() { 3805 logf("Retried.") 3806 }) 3807 defer SetRoundTripRetried(nil) 3808 3809 for i := 0; i < 3; i++ { 3810 t0 := time.Now() 3811 req := tc.req() 3812 res, err := c.Do(req) 3813 if err != nil { 3814 if time.Since(t0) < *MaxWriteWaitBeforeConnReuse/2 { 3815 mu.Lock() 3816 got := logbuf.String() 3817 mu.Unlock() 3818 t.Fatalf("i=%d: Do = %v; log:\n%s", i, err, got) 3819 } 3820 t.Skipf("connection likely wasn't recycled within %d, interfering with actual test; skipping", *MaxWriteWaitBeforeConnReuse) 3821 } 3822 res.Body.Close() 3823 if res.Request != req { 3824 t.Errorf("Response.Request != original request; want identical Request") 3825 } 3826 } 3827 3828 mu.Lock() 3829 got := logbuf.String() 3830 mu.Unlock() 3831 want := fmt.Sprintf(`Dial 3832 Write("%s") 3833 Handler 3834 intentional write failure 3835 Retried. 3836 Dial 3837 Write("%s") 3838 Handler 3839 Write("%s") 3840 Handler 3841 `, tc.reqString, tc.reqString, tc.reqString) 3842 if got != want { 3843 t.Errorf("Log of events differs. 
Got:\n%s\nWant:\n%s", got, want) 3844 } 3845 }) 3846 } 3847 } 3848 3849 // Issue 6981 3850 func TestTransportClosesBodyOnError(t *testing.T) { run(t, testTransportClosesBodyOnError) } 3851 func testTransportClosesBodyOnError(t *testing.T, mode testMode) { 3852 readBody := make(chan error, 1) 3853 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { 3854 _, err := io.ReadAll(r.Body) 3855 readBody <- err 3856 })).ts 3857 c := ts.Client() 3858 fakeErr := errors.New("fake error") 3859 didClose := make(chan bool, 1) 3860 req, _ := NewRequest("POST", ts.URL, struct { 3861 io.Reader 3862 io.Closer 3863 }{ 3864 io.MultiReader(io.LimitReader(neverEnding('x'), 1<<20), iotest.ErrReader(fakeErr)), 3865 closerFunc(func() error { 3866 select { 3867 case didClose <- true: 3868 default: 3869 } 3870 return nil 3871 }), 3872 }) 3873 res, err := c.Do(req) 3874 if res != nil { 3875 defer res.Body.Close() 3876 } 3877 if err == nil || !strings.Contains(err.Error(), fakeErr.Error()) { 3878 t.Fatalf("Do error = %v; want something containing %q", err, fakeErr.Error()) 3879 } 3880 if err := <-readBody; err == nil { 3881 t.Errorf("Unexpected success reading request body from handler; want 'unexpected EOF reading trailer'") 3882 } 3883 select { 3884 case <-didClose: 3885 default: 3886 t.Errorf("didn't see Body.Close") 3887 } 3888 } 3889 3890 func TestTransportDialTLS(t *testing.T) { 3891 run(t, testTransportDialTLS, []testMode{https1Mode, http2Mode}) 3892 } 3893 func testTransportDialTLS(t *testing.T, mode testMode) { 3894 var mu sync.Mutex // guards following 3895 var gotReq, didDial bool 3896 3897 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { 3898 mu.Lock() 3899 gotReq = true 3900 mu.Unlock() 3901 })).ts 3902 c := ts.Client() 3903 c.Transport.(*Transport).DialTLS = func(netw, addr string) (net.Conn, error) { 3904 mu.Lock() 3905 didDial = true 3906 mu.Unlock() 3907 c, err := tls.Dial(netw, addr, 
c.Transport.(*Transport).TLSClientConfig) 3908 if err != nil { 3909 return nil, err 3910 } 3911 return c, c.Handshake() 3912 } 3913 3914 res, err := c.Get(ts.URL) 3915 if err != nil { 3916 t.Fatal(err) 3917 } 3918 res.Body.Close() 3919 mu.Lock() 3920 if !gotReq { 3921 t.Error("didn't get request") 3922 } 3923 if !didDial { 3924 t.Error("didn't use dial hook") 3925 } 3926 } 3927 3928 func TestTransportDialContext(t *testing.T) { run(t, testTransportDialContext) } 3929 func testTransportDialContext(t *testing.T, mode testMode) { 3930 var mu sync.Mutex // guards following 3931 var gotReq bool 3932 var receivedContext context.Context 3933 3934 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { 3935 mu.Lock() 3936 gotReq = true 3937 mu.Unlock() 3938 })).ts 3939 c := ts.Client() 3940 c.Transport.(*Transport).DialContext = func(ctx context.Context, netw, addr string) (net.Conn, error) { 3941 mu.Lock() 3942 receivedContext = ctx 3943 mu.Unlock() 3944 return net.Dial(netw, addr) 3945 } 3946 3947 req, err := NewRequest("GET", ts.URL, nil) 3948 if err != nil { 3949 t.Fatal(err) 3950 } 3951 ctx := context.WithValue(context.Background(), "some-key", "some-value") 3952 res, err := c.Do(req.WithContext(ctx)) 3953 if err != nil { 3954 t.Fatal(err) 3955 } 3956 res.Body.Close() 3957 mu.Lock() 3958 if !gotReq { 3959 t.Error("didn't get request") 3960 } 3961 if receivedContext != ctx { 3962 t.Error("didn't receive correct context") 3963 } 3964 } 3965 3966 func TestTransportDialTLSContext(t *testing.T) { 3967 run(t, testTransportDialTLSContext, []testMode{https1Mode, http2Mode}) 3968 } 3969 func testTransportDialTLSContext(t *testing.T, mode testMode) { 3970 var mu sync.Mutex // guards following 3971 var gotReq bool 3972 var receivedContext context.Context 3973 3974 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { 3975 mu.Lock() 3976 gotReq = true 3977 mu.Unlock() 3978 })).ts 3979 c := ts.Client() 3980 
c.Transport.(*Transport).DialTLSContext = func(ctx context.Context, netw, addr string) (net.Conn, error) { 3981 mu.Lock() 3982 receivedContext = ctx 3983 mu.Unlock() 3984 c, err := tls.Dial(netw, addr, c.Transport.(*Transport).TLSClientConfig) 3985 if err != nil { 3986 return nil, err 3987 } 3988 return c, c.HandshakeContext(ctx) 3989 } 3990 3991 req, err := NewRequest("GET", ts.URL, nil) 3992 if err != nil { 3993 t.Fatal(err) 3994 } 3995 ctx := context.WithValue(context.Background(), "some-key", "some-value") 3996 res, err := c.Do(req.WithContext(ctx)) 3997 if err != nil { 3998 t.Fatal(err) 3999 } 4000 res.Body.Close() 4001 mu.Lock() 4002 if !gotReq { 4003 t.Error("didn't get request") 4004 } 4005 if receivedContext != ctx { 4006 t.Error("didn't receive correct context") 4007 } 4008 } 4009 4010 // Test for issue 8755 4011 // Ensure that if a proxy returns an error, it is exposed by RoundTrip 4012 func TestRoundTripReturnsProxyError(t *testing.T) { 4013 badProxy := func(*Request) (*url.URL, error) { 4014 return nil, errors.New("errorMessage") 4015 } 4016 4017 tr := &Transport{Proxy: badProxy} 4018 4019 req, _ := NewRequest("GET", "http://example.com", nil) 4020 4021 _, err := tr.RoundTrip(req) 4022 4023 if err == nil { 4024 t.Error("Expected proxy error to be returned by RoundTrip") 4025 } 4026 } 4027 4028 // tests that putting an idle conn after a call to CloseIdleConns does return it 4029 func TestTransportCloseIdleConnsThenReturn(t *testing.T) { 4030 tr := &Transport{} 4031 wantIdle := func(when string, n int) bool { 4032 got := tr.IdleConnCountForTesting("http", "example.com") // key used by PutIdleTestConn 4033 if got == n { 4034 return true 4035 } 4036 t.Errorf("%s: idle conns = %d; want %d", when, got, n) 4037 return false 4038 } 4039 wantIdle("start", 0) 4040 if !tr.PutIdleTestConn("http", "example.com") { 4041 t.Fatal("put failed") 4042 } 4043 if !tr.PutIdleTestConn("http", "example.com") { 4044 t.Fatal("second put failed") 4045 } 4046 wantIdle("after 
put", 2) 4047 tr.CloseIdleConnections() 4048 if !tr.IsIdleForTesting() { 4049 t.Error("should be idle after CloseIdleConnections") 4050 } 4051 wantIdle("after close idle", 0) 4052 if tr.PutIdleTestConn("http", "example.com") { 4053 t.Fatal("put didn't fail") 4054 } 4055 wantIdle("after second put", 0) 4056 4057 tr.QueueForIdleConnForTesting() // should toggle the transport out of idle mode 4058 if tr.IsIdleForTesting() { 4059 t.Error("shouldn't be idle after QueueForIdleConnForTesting") 4060 } 4061 if !tr.PutIdleTestConn("http", "example.com") { 4062 t.Fatal("after re-activation") 4063 } 4064 wantIdle("after final put", 1) 4065 } 4066 4067 // Test for issue 34282 4068 // Ensure that getConn doesn't call the GotConn trace hook on an HTTP/2 idle conn 4069 func TestTransportTraceGotConnH2IdleConns(t *testing.T) { 4070 tr := &Transport{} 4071 wantIdle := func(when string, n int) bool { 4072 got := tr.IdleConnCountForTesting("https", "example.com:443") // key used by PutIdleTestConnH2 4073 if got == n { 4074 return true 4075 } 4076 t.Errorf("%s: idle conns = %d; want %d", when, got, n) 4077 return false 4078 } 4079 wantIdle("start", 0) 4080 alt := funcRoundTripper(func() {}) 4081 if !tr.PutIdleTestConnH2("https", "example.com:443", alt) { 4082 t.Fatal("put failed") 4083 } 4084 wantIdle("after put", 1) 4085 ctx := httptrace.WithClientTrace(context.Background(), &httptrace.ClientTrace{ 4086 GotConn: func(httptrace.GotConnInfo) { 4087 // tr.getConn should leave it for the HTTP/2 alt to call GotConn. 
4088 t.Error("GotConn called") 4089 }, 4090 }) 4091 req, _ := NewRequestWithContext(ctx, MethodGet, "https://example.com", nil) 4092 _, err := tr.RoundTrip(req) 4093 if err != errFakeRoundTrip { 4094 t.Errorf("got error: %v; want %q", err, errFakeRoundTrip) 4095 } 4096 wantIdle("after round trip", 1) 4097 } 4098 4099 func TestTransportRemovesH2ConnsAfterIdle(t *testing.T) { 4100 run(t, testTransportRemovesH2ConnsAfterIdle, []testMode{http2Mode}) 4101 } 4102 func testTransportRemovesH2ConnsAfterIdle(t *testing.T, mode testMode) { 4103 if testing.Short() { 4104 t.Skip("skipping in short mode") 4105 } 4106 4107 timeout := 1 * time.Millisecond 4108 retry := true 4109 for retry { 4110 trFunc := func(tr *Transport) { 4111 tr.MaxConnsPerHost = 1 4112 tr.MaxIdleConnsPerHost = 1 4113 tr.IdleConnTimeout = timeout 4114 } 4115 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {}), trFunc) 4116 4117 retry = false 4118 tooShort := func(err error) bool { 4119 if err == nil || !strings.Contains(err.Error(), "use of closed network connection") { 4120 return false 4121 } 4122 if !retry { 4123 t.Helper() 4124 t.Logf("idle conn timeout %v may be too short; retrying with longer", timeout) 4125 timeout *= 2 4126 retry = true 4127 cst.close() 4128 } 4129 return true 4130 } 4131 4132 if _, err := cst.c.Get(cst.ts.URL); err != nil { 4133 if tooShort(err) { 4134 continue 4135 } 4136 t.Fatalf("got error: %s", err) 4137 } 4138 4139 time.Sleep(10 * timeout) 4140 if _, err := cst.c.Get(cst.ts.URL); err != nil { 4141 if tooShort(err) { 4142 continue 4143 } 4144 t.Fatalf("got error: %s", err) 4145 } 4146 } 4147 } 4148 4149 // This tests that a client requesting a content range won't also 4150 // implicitly ask for gzip support. If they want that, they need to do it 4151 // on their own. 
4152 // golang.org/issue/8923 4153 func TestTransportRangeAndGzip(t *testing.T) { run(t, testTransportRangeAndGzip) } 4154 func testTransportRangeAndGzip(t *testing.T, mode testMode) { 4155 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { 4156 if strings.Contains(r.Header.Get("Accept-Encoding"), "gzip") { 4157 t.Error("Transport advertised gzip support in the Accept header") 4158 } 4159 if r.Header.Get("Range") == "" { 4160 t.Error("no Range in request") 4161 } 4162 })).ts 4163 c := ts.Client() 4164 4165 req, _ := NewRequest("GET", ts.URL, nil) 4166 req.Header.Set("Range", "bytes=7-11") 4167 res, err := c.Do(req) 4168 if err != nil { 4169 t.Fatal(err) 4170 } 4171 res.Body.Close() 4172 } 4173 4174 // Test for issue 10474 4175 func TestTransportResponseCancelRace(t *testing.T) { run(t, testTransportResponseCancelRace) } 4176 func testTransportResponseCancelRace(t *testing.T, mode testMode) { 4177 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { 4178 // important that this response has a body. 4179 var b [1024]byte 4180 w.Write(b[:]) 4181 })).ts 4182 tr := ts.Client().Transport.(*Transport) 4183 4184 req, err := NewRequest("GET", ts.URL, nil) 4185 if err != nil { 4186 t.Fatal(err) 4187 } 4188 res, err := tr.RoundTrip(req) 4189 if err != nil { 4190 t.Fatal(err) 4191 } 4192 // If we do an early close, Transport just throws the connection away and 4193 // doesn't reuse it. In order to trigger the bug, it has to reuse the connection 4194 // so read the body 4195 if _, err := io.Copy(io.Discard, res.Body); err != nil { 4196 t.Fatal(err) 4197 } 4198 4199 req2, err := NewRequest("GET", ts.URL, nil) 4200 if err != nil { 4201 t.Fatal(err) 4202 } 4203 tr.CancelRequest(req) 4204 res, err = tr.RoundTrip(req2) 4205 if err != nil { 4206 t.Fatal(err) 4207 } 4208 res.Body.Close() 4209 } 4210 4211 // Test for issue 19248: Content-Encoding's value is case insensitive. 
4212 func TestTransportContentEncodingCaseInsensitive(t *testing.T) { 4213 run(t, testTransportContentEncodingCaseInsensitive) 4214 } 4215 func testTransportContentEncodingCaseInsensitive(t *testing.T, mode testMode) { 4216 for _, ce := range []string{"gzip", "GZIP"} { 4217 ce := ce 4218 t.Run(ce, func(t *testing.T) { 4219 const encodedString = "Hello Gopher" 4220 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { 4221 w.Header().Set("Content-Encoding", ce) 4222 gz := gzip.NewWriter(w) 4223 gz.Write([]byte(encodedString)) 4224 gz.Close() 4225 })).ts 4226 4227 res, err := ts.Client().Get(ts.URL) 4228 if err != nil { 4229 t.Fatal(err) 4230 } 4231 4232 body, err := io.ReadAll(res.Body) 4233 res.Body.Close() 4234 if err != nil { 4235 t.Fatal(err) 4236 } 4237 4238 if string(body) != encodedString { 4239 t.Fatalf("Expected body %q, got: %q\n", encodedString, string(body)) 4240 } 4241 }) 4242 } 4243 } 4244 4245 func TestTransportDialCancelRace(t *testing.T) { 4246 run(t, testTransportDialCancelRace, testNotParallel, []testMode{http1Mode}) 4247 } 4248 func testTransportDialCancelRace(t *testing.T, mode testMode) { 4249 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {})).ts 4250 tr := ts.Client().Transport.(*Transport) 4251 4252 req, err := NewRequest("GET", ts.URL, nil) 4253 if err != nil { 4254 t.Fatal(err) 4255 } 4256 SetEnterRoundTripHook(func() { 4257 tr.CancelRequest(req) 4258 }) 4259 defer SetEnterRoundTripHook(nil) 4260 res, err := tr.RoundTrip(req) 4261 if err != ExportErrRequestCanceled { 4262 t.Errorf("expected canceled request error; got %v", err) 4263 if err == nil { 4264 res.Body.Close() 4265 } 4266 } 4267 } 4268 4269 // https://go.dev/issue/49621 4270 func TestConnClosedBeforeRequestIsWritten(t *testing.T) { 4271 run(t, testConnClosedBeforeRequestIsWritten, testNotParallel, []testMode{http1Mode}) 4272 } 4273 func testConnClosedBeforeRequestIsWritten(t *testing.T, mode testMode) { 4274 ts := 
newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {}), 4275 func(tr *Transport) { 4276 tr.DialContext = func(_ context.Context, network, addr string) (net.Conn, error) { 4277 // Connection immediately returns errors. 4278 return &funcConn{ 4279 read: func([]byte) (int, error) { 4280 return 0, errors.New("error") 4281 }, 4282 write: func([]byte) (int, error) { 4283 return 0, errors.New("error") 4284 }, 4285 }, nil 4286 } 4287 }, 4288 ).ts 4289 // Set a short delay in RoundTrip to give the persistConn time to notice 4290 // the connection is broken. We want to exercise the path where writeLoop exits 4291 // before it reads the request to send. If this delay is too short, we may instead 4292 // exercise the path where writeLoop accepts the request and then fails to write it. 4293 // That's fine, so long as we get the desired path often enough. 4294 SetEnterRoundTripHook(func() { 4295 time.Sleep(1 * time.Millisecond) 4296 }) 4297 defer SetEnterRoundTripHook(nil) 4298 var closes int 4299 _, err := ts.Client().Post(ts.URL, "text/plain", countCloseReader{&closes, strings.NewReader("hello")}) 4300 if err == nil { 4301 t.Fatalf("expected request to fail, but it did not") 4302 } 4303 if closes != 1 { 4304 t.Errorf("after RoundTrip, request body was closed %v times; want 1", closes) 4305 } 4306 } 4307 4308 // logWritesConn is a net.Conn that logs each Write call to writes 4309 // and then proxies to w. 4310 // It proxies Read calls to a reader it receives from rch. 4311 type logWritesConn struct { 4312 net.Conn // nil. crash on use. 
4313 4314 w io.Writer 4315 4316 rch <-chan io.Reader 4317 r io.Reader // nil until received by rch 4318 4319 mu sync.Mutex 4320 writes []string 4321 } 4322 4323 func (c *logWritesConn) Write(p []byte) (n int, err error) { 4324 c.mu.Lock() 4325 defer c.mu.Unlock() 4326 c.writes = append(c.writes, string(p)) 4327 return c.w.Write(p) 4328 } 4329 4330 func (c *logWritesConn) Read(p []byte) (n int, err error) { 4331 if c.r == nil { 4332 c.r = <-c.rch 4333 } 4334 return c.r.Read(p) 4335 } 4336 4337 func (c *logWritesConn) Close() error { return nil } 4338 4339 // Issue 6574 4340 func TestTransportFlushesBodyChunks(t *testing.T) { 4341 defer afterTest(t) 4342 resBody := make(chan io.Reader, 1) 4343 connr, connw := io.Pipe() // connection pipe pair 4344 lw := &logWritesConn{ 4345 rch: resBody, 4346 w: connw, 4347 } 4348 tr := &Transport{ 4349 Dial: func(network, addr string) (net.Conn, error) { 4350 return lw, nil 4351 }, 4352 } 4353 bodyr, bodyw := io.Pipe() // body pipe pair 4354 go func() { 4355 defer bodyw.Close() 4356 for i := 0; i < 3; i++ { 4357 fmt.Fprintf(bodyw, "num%d\n", i) 4358 } 4359 }() 4360 resc := make(chan *Response) 4361 go func() { 4362 req, _ := NewRequest("POST", "http://localhost:8080", bodyr) 4363 req.Header.Set("User-Agent", "x") // known value for test 4364 res, err := tr.RoundTrip(req) 4365 if err != nil { 4366 t.Errorf("RoundTrip: %v", err) 4367 close(resc) 4368 return 4369 } 4370 resc <- res 4371 4372 }() 4373 // Fully consume the request before checking the Write log vs. want. 4374 req, err := ReadRequest(bufio.NewReader(connr)) 4375 if err != nil { 4376 t.Fatal(err) 4377 } 4378 io.Copy(io.Discard, req.Body) 4379 4380 // Unblock the transport's roundTrip goroutine. 
4381 resBody <- strings.NewReader("HTTP/1.1 204 No Content\r\nConnection: close\r\n\r\n") 4382 res, ok := <-resc 4383 if !ok { 4384 return 4385 } 4386 defer res.Body.Close() 4387 4388 want := []string{ 4389 "POST / HTTP/1.1\r\nHost: localhost:8080\r\nUser-Agent: x\r\nTransfer-Encoding: chunked\r\nAccept-Encoding: gzip\r\n\r\n", 4390 "5\r\nnum0\n\r\n", 4391 "5\r\nnum1\n\r\n", 4392 "5\r\nnum2\n\r\n", 4393 "0\r\n\r\n", 4394 } 4395 if !reflect.DeepEqual(lw.writes, want) { 4396 t.Errorf("Writes differed.\n Got: %q\nWant: %q\n", lw.writes, want) 4397 } 4398 } 4399 4400 // Issue 22088: flush Transport request headers if we're not sure the body won't block on read. 4401 func TestTransportFlushesRequestHeader(t *testing.T) { run(t, testTransportFlushesRequestHeader) } 4402 func testTransportFlushesRequestHeader(t *testing.T, mode testMode) { 4403 gotReq := make(chan struct{}) 4404 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { 4405 close(gotReq) 4406 })) 4407 4408 pr, pw := io.Pipe() 4409 req, err := NewRequest("POST", cst.ts.URL, pr) 4410 if err != nil { 4411 t.Fatal(err) 4412 } 4413 gotRes := make(chan struct{}) 4414 go func() { 4415 defer close(gotRes) 4416 res, err := cst.tr.RoundTrip(req) 4417 if err != nil { 4418 t.Error(err) 4419 return 4420 } 4421 res.Body.Close() 4422 }() 4423 4424 <-gotReq 4425 pw.Close() 4426 <-gotRes 4427 } 4428 4429 type wgReadCloser struct { 4430 io.Reader 4431 wg *sync.WaitGroup 4432 closed bool 4433 } 4434 4435 func (c *wgReadCloser) Close() error { 4436 if c.closed { 4437 return net.ErrClosed 4438 } 4439 c.closed = true 4440 c.wg.Done() 4441 return nil 4442 } 4443 4444 // Issue 11745. 4445 func TestTransportPrefersResponseOverWriteError(t *testing.T) { 4446 // Not parallel: modifies the global rstAvoidanceDelay. 
4447 run(t, testTransportPrefersResponseOverWriteError, testNotParallel) 4448 } 4449 func testTransportPrefersResponseOverWriteError(t *testing.T, mode testMode) { 4450 if testing.Short() { 4451 t.Skip("skipping in short mode") 4452 } 4453 4454 runTimeSensitiveTest(t, []time.Duration{ 4455 1 * time.Millisecond, 4456 5 * time.Millisecond, 4457 10 * time.Millisecond, 4458 50 * time.Millisecond, 4459 100 * time.Millisecond, 4460 500 * time.Millisecond, 4461 time.Second, 4462 5 * time.Second, 4463 }, func(t *testing.T, timeout time.Duration) error { 4464 SetRSTAvoidanceDelay(t, timeout) 4465 t.Logf("set RST avoidance delay to %v", timeout) 4466 4467 const contentLengthLimit = 1024 * 1024 // 1MB 4468 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { 4469 if r.ContentLength >= contentLengthLimit { 4470 w.WriteHeader(StatusBadRequest) 4471 r.Body.Close() 4472 return 4473 } 4474 w.WriteHeader(StatusOK) 4475 })) 4476 // We need to close cst explicitly here so that in-flight server 4477 // requests don't race with the call to SetRSTAvoidanceDelay for a retry. 
4478 defer cst.close() 4479 ts := cst.ts 4480 c := ts.Client() 4481 4482 count := 100 4483 4484 bigBody := strings.Repeat("a", contentLengthLimit*2) 4485 var wg sync.WaitGroup 4486 defer wg.Wait() 4487 getBody := func() (io.ReadCloser, error) { 4488 wg.Add(1) 4489 body := &wgReadCloser{ 4490 Reader: strings.NewReader(bigBody), 4491 wg: &wg, 4492 } 4493 return body, nil 4494 } 4495 4496 for i := 0; i < count; i++ { 4497 reqBody, _ := getBody() 4498 req, err := NewRequest("PUT", ts.URL, reqBody) 4499 if err != nil { 4500 reqBody.Close() 4501 t.Fatal(err) 4502 } 4503 req.ContentLength = int64(len(bigBody)) 4504 req.GetBody = getBody 4505 4506 resp, err := c.Do(req) 4507 if err != nil { 4508 return fmt.Errorf("Do %d: %v", i, err) 4509 } else { 4510 resp.Body.Close() 4511 if resp.StatusCode != 400 { 4512 t.Errorf("Expected status code 400, got %v", resp.Status) 4513 } 4514 } 4515 } 4516 return nil 4517 }) 4518 } 4519 4520 func TestTransportAutomaticHTTP2(t *testing.T) { 4521 testTransportAutoHTTP(t, &Transport{}, true) 4522 } 4523 4524 func TestTransportAutomaticHTTP2_DialerAndTLSConfigSupportsHTTP2AndTLSConfig(t *testing.T) { 4525 testTransportAutoHTTP(t, &Transport{ 4526 ForceAttemptHTTP2: true, 4527 TLSClientConfig: new(tls.Config), 4528 }, true) 4529 } 4530 4531 // golang.org/issue/14391: also check DefaultTransport 4532 func TestTransportAutomaticHTTP2_DefaultTransport(t *testing.T) { 4533 testTransportAutoHTTP(t, DefaultTransport.(*Transport), true) 4534 } 4535 4536 func TestTransportAutomaticHTTP2_TLSNextProto(t *testing.T) { 4537 testTransportAutoHTTP(t, &Transport{ 4538 TLSNextProto: make(map[string]func(string, *tls.Conn) RoundTripper), 4539 }, false) 4540 } 4541 4542 func TestTransportAutomaticHTTP2_TLSConfig(t *testing.T) { 4543 testTransportAutoHTTP(t, &Transport{ 4544 TLSClientConfig: new(tls.Config), 4545 }, false) 4546 } 4547 4548 func TestTransportAutomaticHTTP2_ExpectContinueTimeout(t *testing.T) { 4549 testTransportAutoHTTP(t, &Transport{ 4550 
ExpectContinueTimeout: 1 * time.Second, 4551 }, true) 4552 } 4553 4554 func TestTransportAutomaticHTTP2_Dial(t *testing.T) { 4555 var d net.Dialer 4556 testTransportAutoHTTP(t, &Transport{ 4557 Dial: d.Dial, 4558 }, false) 4559 } 4560 4561 func TestTransportAutomaticHTTP2_DialContext(t *testing.T) { 4562 var d net.Dialer 4563 testTransportAutoHTTP(t, &Transport{ 4564 DialContext: d.DialContext, 4565 }, false) 4566 } 4567 4568 func TestTransportAutomaticHTTP2_DialTLS(t *testing.T) { 4569 testTransportAutoHTTP(t, &Transport{ 4570 DialTLS: func(network, addr string) (net.Conn, error) { 4571 panic("unused") 4572 }, 4573 }, false) 4574 } 4575 4576 func testTransportAutoHTTP(t *testing.T, tr *Transport, wantH2 bool) { 4577 CondSkipHTTP2(t) 4578 _, err := tr.RoundTrip(new(Request)) 4579 if err == nil { 4580 t.Error("expected error from RoundTrip") 4581 } 4582 if reg := tr.TLSNextProto["h2"] != nil; reg != wantH2 { 4583 t.Errorf("HTTP/2 registered = %v; want %v", reg, wantH2) 4584 } 4585 } 4586 4587 // Issue 13633: there was a race where we returned bodyless responses 4588 // to callers before recycling the persistent connection, which meant 4589 // a client doing two subsequent requests could end up on different 4590 // connections. It's somewhat harmless but enough tests assume it's 4591 // not true in order to test other things that it's worth fixing. 4592 // Plus it's nice to be consistent and not have timing-dependent 4593 // behavior. 4594 func TestTransportReuseConnEmptyResponseBody(t *testing.T) { 4595 run(t, testTransportReuseConnEmptyResponseBody) 4596 } 4597 func testTransportReuseConnEmptyResponseBody(t *testing.T, mode testMode) { 4598 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { 4599 w.Header().Set("X-Addr", r.RemoteAddr) 4600 // Empty response body. 
4601 })) 4602 n := 100 4603 if testing.Short() { 4604 n = 10 4605 } 4606 var firstAddr string 4607 for i := 0; i < n; i++ { 4608 res, err := cst.c.Get(cst.ts.URL) 4609 if err != nil { 4610 log.Fatal(err) 4611 } 4612 addr := res.Header.Get("X-Addr") 4613 if i == 0 { 4614 firstAddr = addr 4615 } else if addr != firstAddr { 4616 t.Fatalf("On request %d, addr %q != original addr %q", i+1, addr, firstAddr) 4617 } 4618 res.Body.Close() 4619 } 4620 } 4621 4622 // Issue 13839 4623 func TestNoCrashReturningTransportAltConn(t *testing.T) { 4624 cert, err := tls.X509KeyPair(testcert.LocalhostCert, testcert.LocalhostKey) 4625 if err != nil { 4626 t.Fatal(err) 4627 } 4628 ln := newLocalListener(t) 4629 defer ln.Close() 4630 4631 var wg sync.WaitGroup 4632 SetPendingDialHooks(func() { wg.Add(1) }, wg.Done) 4633 defer SetPendingDialHooks(nil, nil) 4634 4635 testDone := make(chan struct{}) 4636 defer close(testDone) 4637 go func() { 4638 tln := tls.NewListener(ln, &tls.Config{ 4639 NextProtos: []string{"foo"}, 4640 Certificates: []tls.Certificate{cert}, 4641 }) 4642 sc, err := tln.Accept() 4643 if err != nil { 4644 t.Error(err) 4645 return 4646 } 4647 if err := sc.(*tls.Conn).Handshake(); err != nil { 4648 t.Error(err) 4649 return 4650 } 4651 <-testDone 4652 sc.Close() 4653 }() 4654 4655 addr := ln.Addr().String() 4656 4657 req, _ := NewRequest("GET", "https://fake.tld/", nil) 4658 cancel := make(chan struct{}) 4659 req.Cancel = cancel 4660 4661 doReturned := make(chan bool, 1) 4662 madeRoundTripper := make(chan bool, 1) 4663 4664 tr := &Transport{ 4665 DisableKeepAlives: true, 4666 TLSNextProto: map[string]func(string, *tls.Conn) RoundTripper{ 4667 "foo": func(authority string, c *tls.Conn) RoundTripper { 4668 madeRoundTripper <- true 4669 return funcRoundTripper(func() { 4670 t.Error("foo RoundTripper should not be called") 4671 }) 4672 }, 4673 }, 4674 Dial: func(_, _ string) (net.Conn, error) { 4675 panic("shouldn't be called") 4676 }, 4677 DialTLS: func(_, _ string) (net.Conn, 
error) { 4678 tc, err := tls.Dial("tcp", addr, &tls.Config{ 4679 InsecureSkipVerify: true, 4680 NextProtos: []string{"foo"}, 4681 }) 4682 if err != nil { 4683 return nil, err 4684 } 4685 if err := tc.Handshake(); err != nil { 4686 return nil, err 4687 } 4688 close(cancel) 4689 <-doReturned 4690 return tc, nil 4691 }, 4692 } 4693 c := &Client{Transport: tr} 4694 4695 _, err = c.Do(req) 4696 if ue, ok := err.(*url.Error); !ok || ue.Err != ExportErrRequestCanceledConn { 4697 t.Fatalf("Do error = %v; want url.Error with errRequestCanceledConn", err) 4698 } 4699 4700 doReturned <- true 4701 <-madeRoundTripper 4702 wg.Wait() 4703 } 4704 4705 func TestTransportReuseConnection_Gzip_Chunked(t *testing.T) { 4706 run(t, func(t *testing.T, mode testMode) { 4707 testTransportReuseConnection_Gzip(t, mode, true) 4708 }) 4709 } 4710 4711 func TestTransportReuseConnection_Gzip_ContentLength(t *testing.T) { 4712 run(t, func(t *testing.T, mode testMode) { 4713 testTransportReuseConnection_Gzip(t, mode, false) 4714 }) 4715 } 4716 4717 // Make sure we re-use underlying TCP connection for gzipped responses too. 
4718 func testTransportReuseConnection_Gzip(t *testing.T, mode testMode, chunked bool) { 4719 addr := make(chan string, 2) 4720 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { 4721 addr <- r.RemoteAddr 4722 w.Header().Set("Content-Encoding", "gzip") 4723 if chunked { 4724 w.(Flusher).Flush() 4725 } 4726 w.Write(rgz) // arbitrary gzip response 4727 })).ts 4728 c := ts.Client() 4729 4730 trace := &httptrace.ClientTrace{ 4731 GetConn: func(hostPort string) { t.Logf("GetConn(%q)", hostPort) }, 4732 GotConn: func(ci httptrace.GotConnInfo) { t.Logf("GotConn(%+v)", ci) }, 4733 PutIdleConn: func(err error) { t.Logf("PutIdleConn(%v)", err) }, 4734 ConnectStart: func(network, addr string) { t.Logf("ConnectStart(%q, %q)", network, addr) }, 4735 ConnectDone: func(network, addr string, err error) { t.Logf("ConnectDone(%q, %q, %v)", network, addr, err) }, 4736 } 4737 ctx := httptrace.WithClientTrace(context.Background(), trace) 4738 4739 for i := 0; i < 2; i++ { 4740 req, _ := NewRequest("GET", ts.URL, nil) 4741 req = req.WithContext(ctx) 4742 res, err := c.Do(req) 4743 if err != nil { 4744 t.Fatal(err) 4745 } 4746 buf := make([]byte, len(rgz)) 4747 if n, err := io.ReadFull(res.Body, buf); err != nil { 4748 t.Errorf("%d. ReadFull = %v, %v", i, n, err) 4749 } 4750 // Note: no res.Body.Close call. It should work without it, 4751 // since the flate.Reader's internal buffering will hit EOF 4752 // and that should be sufficient. 
4753 } 4754 a1, a2 := <-addr, <-addr 4755 if a1 != a2 { 4756 t.Fatalf("didn't reuse connection") 4757 } 4758 } 4759 4760 func TestTransportResponseHeaderLength(t *testing.T) { run(t, testTransportResponseHeaderLength) } 4761 func testTransportResponseHeaderLength(t *testing.T, mode testMode) { 4762 if mode == http2Mode { 4763 t.Skip("HTTP/2 Transport doesn't support MaxResponseHeaderBytes") 4764 } 4765 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { 4766 if r.URL.Path == "/long" { 4767 w.Header().Set("Long", strings.Repeat("a", 1<<20)) 4768 } 4769 })).ts 4770 c := ts.Client() 4771 c.Transport.(*Transport).MaxResponseHeaderBytes = 512 << 10 4772 4773 if res, err := c.Get(ts.URL); err != nil { 4774 t.Fatal(err) 4775 } else { 4776 res.Body.Close() 4777 } 4778 4779 res, err := c.Get(ts.URL + "/long") 4780 if err == nil { 4781 defer res.Body.Close() 4782 var n int64 4783 for k, vv := range res.Header { 4784 for _, v := range vv { 4785 n += int64(len(k)) + int64(len(v)) 4786 } 4787 } 4788 t.Fatalf("Unexpected success. Got %v and %d bytes of response headers", res.Status, n) 4789 } 4790 if want := "server response headers exceeded 524288 bytes"; !strings.Contains(err.Error(), want) { 4791 t.Errorf("got error: %v; want %q", err, want) 4792 } 4793 } 4794 4795 func TestTransportEventTrace(t *testing.T) { 4796 run(t, func(t *testing.T, mode testMode) { 4797 testTransportEventTrace(t, mode, false) 4798 }, testNotParallel) 4799 } 4800 4801 // test a non-nil httptrace.ClientTrace but with all hooks set to zero. 
4802 func TestTransportEventTrace_NoHooks(t *testing.T) { 4803 run(t, func(t *testing.T, mode testMode) { 4804 testTransportEventTrace(t, mode, true) 4805 }, testNotParallel) 4806 } 4807 4808 func testTransportEventTrace(t *testing.T, mode testMode, noHooks bool) { 4809 const resBody = "some body" 4810 gotWroteReqEvent := make(chan struct{}, 500) 4811 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { 4812 if r.Method == "GET" { 4813 // Do nothing for the second request. 4814 return 4815 } 4816 if _, err := io.ReadAll(r.Body); err != nil { 4817 t.Error(err) 4818 } 4819 if !noHooks { 4820 <-gotWroteReqEvent 4821 } 4822 io.WriteString(w, resBody) 4823 }), func(tr *Transport) { 4824 if tr.TLSClientConfig != nil { 4825 tr.TLSClientConfig.InsecureSkipVerify = true 4826 } 4827 }) 4828 defer cst.close() 4829 4830 cst.tr.ExpectContinueTimeout = 1 * time.Second 4831 4832 var mu sync.Mutex // guards buf 4833 var buf strings.Builder 4834 logf := func(format string, args ...any) { 4835 mu.Lock() 4836 defer mu.Unlock() 4837 fmt.Fprintf(&buf, format, args...) 4838 buf.WriteByte('\n') 4839 } 4840 4841 addrStr := cst.ts.Listener.Addr().String() 4842 ip, port, err := net.SplitHostPort(addrStr) 4843 if err != nil { 4844 t.Fatal(err) 4845 } 4846 4847 // Install a fake DNS server. 
4848 ctx := context.WithValue(context.Background(), nettrace.LookupIPAltResolverKey{}, func(ctx context.Context, network, host string) ([]net.IPAddr, error) { 4849 if host != "dns-is-faked.golang" { 4850 t.Errorf("unexpected DNS host lookup for %q/%q", network, host) 4851 return nil, nil 4852 } 4853 return []net.IPAddr{{IP: net.ParseIP(ip)}}, nil 4854 }) 4855 4856 body := "some body" 4857 req, _ := NewRequest("POST", cst.scheme()+"://dns-is-faked.golang:"+port, strings.NewReader(body)) 4858 req.Header["X-Foo-Multiple-Vals"] = []string{"bar", "baz"} 4859 trace := &httptrace.ClientTrace{ 4860 GetConn: func(hostPort string) { logf("Getting conn for %v ...", hostPort) }, 4861 GotConn: func(ci httptrace.GotConnInfo) { logf("got conn: %+v", ci) }, 4862 GotFirstResponseByte: func() { logf("first response byte") }, 4863 PutIdleConn: func(err error) { logf("PutIdleConn = %v", err) }, 4864 DNSStart: func(e httptrace.DNSStartInfo) { logf("DNS start: %+v", e) }, 4865 DNSDone: func(e httptrace.DNSDoneInfo) { logf("DNS done: %+v", e) }, 4866 ConnectStart: func(network, addr string) { logf("ConnectStart: Connecting to %s %s ...", network, addr) }, 4867 ConnectDone: func(network, addr string, err error) { 4868 if err != nil { 4869 t.Errorf("ConnectDone: %v", err) 4870 } 4871 logf("ConnectDone: connected to %s %s = %v", network, addr, err) 4872 }, 4873 WroteHeaderField: func(key string, value []string) { 4874 logf("WroteHeaderField: %s: %v", key, value) 4875 }, 4876 WroteHeaders: func() { 4877 logf("WroteHeaders") 4878 }, 4879 Wait100Continue: func() { logf("Wait100Continue") }, 4880 Got100Continue: func() { logf("Got100Continue") }, 4881 WroteRequest: func(e httptrace.WroteRequestInfo) { 4882 logf("WroteRequest: %+v", e) 4883 gotWroteReqEvent <- struct{}{} 4884 }, 4885 } 4886 if mode == http2Mode { 4887 trace.TLSHandshakeStart = func() { logf("tls handshake start") } 4888 trace.TLSHandshakeDone = func(s tls.ConnectionState, err error) { 4889 logf("tls handshake done. 
ConnectionState = %v \n err = %v", s, err) 4890 } 4891 } 4892 if noHooks { 4893 // zero out all func pointers, trying to get some path to crash 4894 *trace = httptrace.ClientTrace{} 4895 } 4896 req = req.WithContext(httptrace.WithClientTrace(ctx, trace)) 4897 4898 req.Header.Set("Expect", "100-continue") 4899 res, err := cst.c.Do(req) 4900 if err != nil { 4901 t.Fatal(err) 4902 } 4903 logf("got roundtrip.response") 4904 slurp, err := io.ReadAll(res.Body) 4905 if err != nil { 4906 t.Fatal(err) 4907 } 4908 logf("consumed body") 4909 if string(slurp) != resBody || res.StatusCode != 200 { 4910 t.Fatalf("Got %q, %v; want %q, 200 OK", slurp, res.Status, resBody) 4911 } 4912 res.Body.Close() 4913 4914 if noHooks { 4915 // Done at this point. Just testing a full HTTP 4916 // requests can happen with a trace pointing to a zero 4917 // ClientTrace, full of nil func pointers. 4918 return 4919 } 4920 4921 mu.Lock() 4922 got := buf.String() 4923 mu.Unlock() 4924 4925 wantOnce := func(sub string) { 4926 if strings.Count(got, sub) != 1 { 4927 t.Errorf("expected substring %q exactly once in output.", sub) 4928 } 4929 } 4930 wantOnceOrMore := func(sub string) { 4931 if strings.Count(got, sub) == 0 { 4932 t.Errorf("expected substring %q at least once in output.", sub) 4933 } 4934 } 4935 wantOnce("Getting conn for dns-is-faked.golang:" + port) 4936 wantOnce("DNS start: {Host:dns-is-faked.golang}") 4937 wantOnce("DNS done: {Addrs:[{IP:" + ip + " Zone:}] Err:<nil> Coalesced:false}") 4938 wantOnce("got conn: {") 4939 wantOnceOrMore("Connecting to tcp " + addrStr) 4940 wantOnceOrMore("connected to tcp " + addrStr + " = <nil>") 4941 wantOnce("Reused:false WasIdle:false IdleTime:0s") 4942 wantOnce("first response byte") 4943 if mode == http2Mode { 4944 wantOnce("tls handshake start") 4945 wantOnce("tls handshake done") 4946 } else { 4947 wantOnce("PutIdleConn = <nil>") 4948 wantOnce("WroteHeaderField: User-Agent: [Go-http-client/1.1]") 4949 // TODO(meirf): issue 19761. 
Make these agnostic to h1/h2. (These are not h1 specific, but the 4950 // WroteHeaderField hook is not yet implemented in h2.) 4951 wantOnce(fmt.Sprintf("WroteHeaderField: Host: [dns-is-faked.golang:%s]", port)) 4952 wantOnce(fmt.Sprintf("WroteHeaderField: Content-Length: [%d]", len(body))) 4953 wantOnce("WroteHeaderField: X-Foo-Multiple-Vals: [bar baz]") 4954 wantOnce("WroteHeaderField: Accept-Encoding: [gzip]") 4955 } 4956 wantOnce("WroteHeaders") 4957 wantOnce("Wait100Continue") 4958 wantOnce("Got100Continue") 4959 wantOnce("WroteRequest: {Err:<nil>}") 4960 if strings.Contains(got, " to udp ") { 4961 t.Errorf("should not see UDP (DNS) connections") 4962 } 4963 if t.Failed() { 4964 t.Errorf("Output:\n%s", got) 4965 } 4966 4967 // And do a second request: 4968 req, _ = NewRequest("GET", cst.scheme()+"://dns-is-faked.golang:"+port, nil) 4969 req = req.WithContext(httptrace.WithClientTrace(ctx, trace)) 4970 res, err = cst.c.Do(req) 4971 if err != nil { 4972 t.Fatal(err) 4973 } 4974 if res.StatusCode != 200 { 4975 t.Fatal(res.Status) 4976 } 4977 res.Body.Close() 4978 4979 mu.Lock() 4980 got = buf.String() 4981 mu.Unlock() 4982 4983 sub := "Getting conn for dns-is-faked.golang:" 4984 if gotn, want := strings.Count(got, sub), 2; gotn != want { 4985 t.Errorf("substring %q appeared %d times; want %d. Log:\n%s", sub, gotn, want, got) 4986 } 4987 4988 } 4989 4990 func TestTransportEventTraceTLSVerify(t *testing.T) { 4991 run(t, testTransportEventTraceTLSVerify, []testMode{https1Mode, http2Mode}) 4992 } 4993 func testTransportEventTraceTLSVerify(t *testing.T, mode testMode) { 4994 var mu sync.Mutex 4995 var buf strings.Builder 4996 logf := func(format string, args ...any) { 4997 mu.Lock() 4998 defer mu.Unlock() 4999 fmt.Fprintf(&buf, format, args...) 
5000 buf.WriteByte('\n') 5001 } 5002 5003 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { 5004 t.Error("Unexpected request") 5005 }), func(ts *httptest.Server) { 5006 ts.Config.ErrorLog = log.New(funcWriter(func(p []byte) (int, error) { 5007 logf("%s", p) 5008 return len(p), nil 5009 }), "", 0) 5010 }).ts 5011 5012 certpool := x509.NewCertPool() 5013 certpool.AddCert(ts.Certificate()) 5014 5015 c := &Client{Transport: &Transport{ 5016 TLSClientConfig: &tls.Config{ 5017 ServerName: "dns-is-faked.golang", 5018 RootCAs: certpool, 5019 }, 5020 }} 5021 5022 trace := &httptrace.ClientTrace{ 5023 TLSHandshakeStart: func() { logf("TLSHandshakeStart") }, 5024 TLSHandshakeDone: func(s tls.ConnectionState, err error) { 5025 logf("TLSHandshakeDone: ConnectionState = %v \n err = %v", s, err) 5026 }, 5027 } 5028 5029 req, _ := NewRequest("GET", ts.URL, nil) 5030 req = req.WithContext(httptrace.WithClientTrace(context.Background(), trace)) 5031 _, err := c.Do(req) 5032 if err == nil { 5033 t.Error("Expected request to fail TLS verification") 5034 } 5035 5036 mu.Lock() 5037 got := buf.String() 5038 mu.Unlock() 5039 5040 wantOnce := func(sub string) { 5041 if strings.Count(got, sub) != 1 { 5042 t.Errorf("expected substring %q exactly once in output.", sub) 5043 } 5044 } 5045 5046 wantOnce("TLSHandshakeStart") 5047 wantOnce("TLSHandshakeDone") 5048 wantOnce("err = tls: failed to verify certificate: x509: certificate is valid for example.com") 5049 5050 if t.Failed() { 5051 t.Errorf("Output:\n%s", got) 5052 } 5053 } 5054 5055 var ( 5056 isDNSHijackedOnce sync.Once 5057 isDNSHijacked bool 5058 ) 5059 5060 func skipIfDNSHijacked(t *testing.T) { 5061 // Skip this test if the user is using a shady/ISP 5062 // DNS server hijacking queries. 5063 // See issues 16732, 16716. 
5064 isDNSHijackedOnce.Do(func() { 5065 addrs, _ := net.LookupHost("dns-should-not-resolve.golang") 5066 isDNSHijacked = len(addrs) != 0 5067 }) 5068 if isDNSHijacked { 5069 t.Skip("skipping; test requires non-hijacking DNS server") 5070 } 5071 } 5072 5073 func TestTransportEventTraceRealDNS(t *testing.T) { 5074 skipIfDNSHijacked(t) 5075 defer afterTest(t) 5076 tr := &Transport{} 5077 defer tr.CloseIdleConnections() 5078 c := &Client{Transport: tr} 5079 5080 var mu sync.Mutex // guards buf 5081 var buf strings.Builder 5082 logf := func(format string, args ...any) { 5083 mu.Lock() 5084 defer mu.Unlock() 5085 fmt.Fprintf(&buf, format, args...) 5086 buf.WriteByte('\n') 5087 } 5088 5089 req, _ := NewRequest("GET", "http://dns-should-not-resolve.golang:80", nil) 5090 trace := &httptrace.ClientTrace{ 5091 DNSStart: func(e httptrace.DNSStartInfo) { logf("DNSStart: %+v", e) }, 5092 DNSDone: func(e httptrace.DNSDoneInfo) { logf("DNSDone: %+v", e) }, 5093 ConnectStart: func(network, addr string) { logf("ConnectStart: %s %s", network, addr) }, 5094 ConnectDone: func(network, addr string, err error) { logf("ConnectDone: %s %s %v", network, addr, err) }, 5095 } 5096 req = req.WithContext(httptrace.WithClientTrace(context.Background(), trace)) 5097 5098 resp, err := c.Do(req) 5099 if err == nil { 5100 resp.Body.Close() 5101 t.Fatal("expected error during DNS lookup") 5102 } 5103 5104 mu.Lock() 5105 got := buf.String() 5106 mu.Unlock() 5107 5108 wantSub := func(sub string) { 5109 if !strings.Contains(got, sub) { 5110 t.Errorf("expected substring %q in output.", sub) 5111 } 5112 } 5113 wantSub("DNSStart: {Host:dns-should-not-resolve.golang}") 5114 wantSub("DNSDone: {Addrs:[] Err:") 5115 if strings.Contains(got, "ConnectStart") || strings.Contains(got, "ConnectDone") { 5116 t.Errorf("should not see Connect events") 5117 } 5118 if t.Failed() { 5119 t.Errorf("Output:\n%s", got) 5120 } 5121 } 5122 5123 // Issue 14353: port can only contain digits. 
5124 func TestTransportRejectsAlphaPort(t *testing.T) { 5125 res, err := Get("http://dummy.tld:123foo/bar") 5126 if err == nil { 5127 res.Body.Close() 5128 t.Fatal("unexpected success") 5129 } 5130 ue, ok := err.(*url.Error) 5131 if !ok { 5132 t.Fatalf("got %#v; want *url.Error", err) 5133 } 5134 got := ue.Err.Error() 5135 want := `invalid port ":123foo" after host` 5136 if got != want { 5137 t.Errorf("got error %q; want %q", got, want) 5138 } 5139 } 5140 5141 // Test the httptrace.TLSHandshake{Start,Done} hooks with an https http1 5142 // connections. The http2 test is done in TestTransportEventTrace_h2 5143 func TestTLSHandshakeTrace(t *testing.T) { 5144 run(t, testTLSHandshakeTrace, []testMode{https1Mode, http2Mode}) 5145 } 5146 func testTLSHandshakeTrace(t *testing.T, mode testMode) { 5147 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {})).ts 5148 5149 var mu sync.Mutex 5150 var start, done bool 5151 trace := &httptrace.ClientTrace{ 5152 TLSHandshakeStart: func() { 5153 mu.Lock() 5154 defer mu.Unlock() 5155 start = true 5156 }, 5157 TLSHandshakeDone: func(s tls.ConnectionState, err error) { 5158 mu.Lock() 5159 defer mu.Unlock() 5160 done = true 5161 if err != nil { 5162 t.Fatal("Expected error to be nil but was:", err) 5163 } 5164 }, 5165 } 5166 5167 c := ts.Client() 5168 req, err := NewRequest("GET", ts.URL, nil) 5169 if err != nil { 5170 t.Fatal("Unable to construct test request:", err) 5171 } 5172 req = req.WithContext(httptrace.WithClientTrace(req.Context(), trace)) 5173 5174 r, err := c.Do(req) 5175 if err != nil { 5176 t.Fatal("Unexpected error making request:", err) 5177 } 5178 r.Body.Close() 5179 mu.Lock() 5180 defer mu.Unlock() 5181 if !start { 5182 t.Fatal("Expected TLSHandshakeStart to be called, but wasn't") 5183 } 5184 if !done { 5185 t.Fatal("Expected TLSHandshakeDone to be called, but wasn't") 5186 } 5187 } 5188 5189 func TestTransportMaxIdleConns(t *testing.T) { 5190 run(t, testTransportMaxIdleConns, 
[]testMode{http1Mode}) 5191 } 5192 func testTransportMaxIdleConns(t *testing.T, mode testMode) { 5193 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { 5194 // No body for convenience. 5195 })).ts 5196 c := ts.Client() 5197 tr := c.Transport.(*Transport) 5198 tr.MaxIdleConns = 4 5199 5200 ip, port, err := net.SplitHostPort(ts.Listener.Addr().String()) 5201 if err != nil { 5202 t.Fatal(err) 5203 } 5204 ctx := context.WithValue(context.Background(), nettrace.LookupIPAltResolverKey{}, func(ctx context.Context, _, host string) ([]net.IPAddr, error) { 5205 return []net.IPAddr{{IP: net.ParseIP(ip)}}, nil 5206 }) 5207 5208 hitHost := func(n int) { 5209 req, _ := NewRequest("GET", fmt.Sprintf("http://host-%d.dns-is-faked.golang:"+port, n), nil) 5210 req = req.WithContext(ctx) 5211 res, err := c.Do(req) 5212 if err != nil { 5213 t.Fatal(err) 5214 } 5215 res.Body.Close() 5216 } 5217 for i := 0; i < 4; i++ { 5218 hitHost(i) 5219 } 5220 want := []string{ 5221 "|http|host-0.dns-is-faked.golang:" + port, 5222 "|http|host-1.dns-is-faked.golang:" + port, 5223 "|http|host-2.dns-is-faked.golang:" + port, 5224 "|http|host-3.dns-is-faked.golang:" + port, 5225 } 5226 if got := tr.IdleConnKeysForTesting(); !reflect.DeepEqual(got, want) { 5227 t.Fatalf("idle conn keys mismatch.\n got: %q\nwant: %q\n", got, want) 5228 } 5229 5230 // Now hitting the 5th host should kick out the first host: 5231 hitHost(4) 5232 want = []string{ 5233 "|http|host-1.dns-is-faked.golang:" + port, 5234 "|http|host-2.dns-is-faked.golang:" + port, 5235 "|http|host-3.dns-is-faked.golang:" + port, 5236 "|http|host-4.dns-is-faked.golang:" + port, 5237 } 5238 if got := tr.IdleConnKeysForTesting(); !reflect.DeepEqual(got, want) { 5239 t.Fatalf("idle conn keys mismatch after 5th host.\n got: %q\nwant: %q\n", got, want) 5240 } 5241 } 5242 5243 func TestTransportIdleConnTimeout(t *testing.T) { run(t, testTransportIdleConnTimeout) } 5244 func testTransportIdleConnTimeout(t *testing.T, mode 
testMode) { 5245 if testing.Short() { 5246 t.Skip("skipping in short mode") 5247 } 5248 5249 timeout := 1 * time.Millisecond 5250 timeoutLoop: 5251 for { 5252 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { 5253 // No body for convenience. 5254 })) 5255 tr := cst.tr 5256 tr.IdleConnTimeout = timeout 5257 defer tr.CloseIdleConnections() 5258 c := &Client{Transport: tr} 5259 5260 idleConns := func() []string { 5261 if mode == http2Mode { 5262 return tr.IdleConnStrsForTesting_h2() 5263 } else { 5264 return tr.IdleConnStrsForTesting() 5265 } 5266 } 5267 5268 var conn string 5269 doReq := func(n int) (timeoutOk bool) { 5270 req, _ := NewRequest("GET", cst.ts.URL, nil) 5271 req = req.WithContext(httptrace.WithClientTrace(context.Background(), &httptrace.ClientTrace{ 5272 PutIdleConn: func(err error) { 5273 if err != nil { 5274 t.Errorf("failed to keep idle conn: %v", err) 5275 } 5276 }, 5277 })) 5278 res, err := c.Do(req) 5279 if err != nil { 5280 if strings.Contains(err.Error(), "use of closed network connection") { 5281 t.Logf("req %v: connection closed prematurely", n) 5282 return false 5283 } 5284 } 5285 res.Body.Close() 5286 conns := idleConns() 5287 if len(conns) != 1 { 5288 if len(conns) == 0 { 5289 t.Logf("req %v: no idle conns", n) 5290 return false 5291 } 5292 t.Fatalf("req %v: unexpected number of idle conns: %q", n, conns) 5293 } 5294 if conn == "" { 5295 conn = conns[0] 5296 } 5297 if conn != conns[0] { 5298 t.Logf("req %v: cached connection changed; expected the same one throughout the test", n) 5299 return false 5300 } 5301 return true 5302 } 5303 for i := 0; i < 3; i++ { 5304 if !doReq(i) { 5305 t.Logf("idle conn timeout %v appears to be too short; retrying with longer", timeout) 5306 timeout *= 2 5307 cst.close() 5308 continue timeoutLoop 5309 } 5310 time.Sleep(timeout / 2) 5311 } 5312 5313 waitCondition(t, timeout/2, func(d time.Duration) bool { 5314 if got := idleConns(); len(got) != 0 { 5315 if d >= timeout*3/2 { 
5316 t.Logf("after %v, idle conns = %q", d, got) 5317 } 5318 return false 5319 } 5320 return true 5321 }) 5322 break 5323 } 5324 } 5325 5326 // Issue 16208: Go 1.7 crashed after Transport.IdleConnTimeout if an 5327 // HTTP/2 connection was established but its caller no longer 5328 // wanted it. (Assuming the connection cache was enabled, which it is 5329 // by default) 5330 // 5331 // This test reproduced the crash by setting the IdleConnTimeout low 5332 // (to make the test reasonable) and then making a request which is 5333 // canceled by the DialTLS hook, which then also waits to return the 5334 // real connection until after the RoundTrip saw the error. Then we 5335 // know the successful tls.Dial from DialTLS will need to go into the 5336 // idle pool. Then we give it a of time to explode. 5337 func TestIdleConnH2Crash(t *testing.T) { run(t, testIdleConnH2Crash, []testMode{http2Mode}) } 5338 func testIdleConnH2Crash(t *testing.T, mode testMode) { 5339 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { 5340 // nothing 5341 })) 5342 5343 ctx, cancel := context.WithCancel(context.Background()) 5344 defer cancel() 5345 5346 sawDoErr := make(chan bool, 1) 5347 testDone := make(chan struct{}) 5348 defer close(testDone) 5349 5350 cst.tr.IdleConnTimeout = 5 * time.Millisecond 5351 cst.tr.DialTLS = func(network, addr string) (net.Conn, error) { 5352 c, err := tls.Dial(network, addr, &tls.Config{ 5353 InsecureSkipVerify: true, 5354 NextProtos: []string{"h2"}, 5355 }) 5356 if err != nil { 5357 t.Error(err) 5358 return nil, err 5359 } 5360 if cs := c.ConnectionState(); cs.NegotiatedProtocol != "h2" { 5361 t.Errorf("protocol = %q; want %q", cs.NegotiatedProtocol, "h2") 5362 c.Close() 5363 return nil, errors.New("bogus") 5364 } 5365 5366 cancel() 5367 5368 select { 5369 case <-sawDoErr: 5370 case <-testDone: 5371 } 5372 return c, nil 5373 } 5374 5375 req, _ := NewRequest("GET", cst.ts.URL, nil) 5376 req = req.WithContext(ctx) 5377 res, err 
:= cst.c.Do(req) 5378 if err == nil { 5379 res.Body.Close() 5380 t.Fatal("unexpected success") 5381 } 5382 sawDoErr <- true 5383 5384 // Wait for the explosion. 5385 time.Sleep(cst.tr.IdleConnTimeout * 10) 5386 } 5387 5388 type funcConn struct { 5389 net.Conn 5390 read func([]byte) (int, error) 5391 write func([]byte) (int, error) 5392 } 5393 5394 func (c funcConn) Read(p []byte) (int, error) { return c.read(p) } 5395 func (c funcConn) Write(p []byte) (int, error) { return c.write(p) } 5396 func (c funcConn) Close() error { return nil } 5397 5398 // Issue 16465: Transport.RoundTrip should return the raw net.Conn.Read error from Peek 5399 // back to the caller. 5400 func TestTransportReturnsPeekError(t *testing.T) { 5401 errValue := errors.New("specific error value") 5402 5403 wrote := make(chan struct{}) 5404 var wroteOnce sync.Once 5405 5406 tr := &Transport{ 5407 Dial: func(network, addr string) (net.Conn, error) { 5408 c := funcConn{ 5409 read: func([]byte) (int, error) { 5410 <-wrote 5411 return 0, errValue 5412 }, 5413 write: func(p []byte) (int, error) { 5414 wroteOnce.Do(func() { close(wrote) }) 5415 return len(p), nil 5416 }, 5417 } 5418 return c, nil 5419 }, 5420 } 5421 _, err := tr.RoundTrip(httptest.NewRequest("GET", "http://fake.tld/", nil)) 5422 if err != errValue { 5423 t.Errorf("error = %#v; want %v", err, errValue) 5424 } 5425 } 5426 5427 // Issue 13835: international domain names should work 5428 func TestTransportIDNA(t *testing.T) { run(t, testTransportIDNA) } 5429 func testTransportIDNA(t *testing.T, mode testMode) { 5430 const uniDomain = "гофер.го" 5431 const punyDomain = "xn--c1ae0ajs.xn--c1aw" 5432 5433 var port string 5434 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { 5435 want := punyDomain + ":" + port 5436 if r.Host != want { 5437 t.Errorf("Host header = %q; want %q", r.Host, want) 5438 } 5439 if mode == http2Mode { 5440 if r.TLS == nil { 5441 t.Errorf("r.TLS == nil") 5442 } else if r.TLS.ServerName 
!= punyDomain { 5443 t.Errorf("TLS.ServerName = %q; want %q", r.TLS.ServerName, punyDomain) 5444 } 5445 } 5446 w.Header().Set("Hit-Handler", "1") 5447 }), func(tr *Transport) { 5448 if tr.TLSClientConfig != nil { 5449 tr.TLSClientConfig.InsecureSkipVerify = true 5450 } 5451 }) 5452 5453 ip, port, err := net.SplitHostPort(cst.ts.Listener.Addr().String()) 5454 if err != nil { 5455 t.Fatal(err) 5456 } 5457 5458 // Install a fake DNS server. 5459 ctx := context.WithValue(context.Background(), nettrace.LookupIPAltResolverKey{}, func(ctx context.Context, network, host string) ([]net.IPAddr, error) { 5460 if host != punyDomain { 5461 t.Errorf("got DNS host lookup for %q/%q; want %q", network, host, punyDomain) 5462 return nil, nil 5463 } 5464 return []net.IPAddr{{IP: net.ParseIP(ip)}}, nil 5465 }) 5466 5467 req, _ := NewRequest("GET", cst.scheme()+"://"+uniDomain+":"+port, nil) 5468 trace := &httptrace.ClientTrace{ 5469 GetConn: func(hostPort string) { 5470 want := net.JoinHostPort(punyDomain, port) 5471 if hostPort != want { 5472 t.Errorf("getting conn for %q; want %q", hostPort, want) 5473 } 5474 }, 5475 DNSStart: func(e httptrace.DNSStartInfo) { 5476 if e.Host != punyDomain { 5477 t.Errorf("DNSStart Host = %q; want %q", e.Host, punyDomain) 5478 } 5479 }, 5480 } 5481 req = req.WithContext(httptrace.WithClientTrace(ctx, trace)) 5482 5483 res, err := cst.tr.RoundTrip(req) 5484 if err != nil { 5485 t.Fatal(err) 5486 } 5487 defer res.Body.Close() 5488 if res.Header.Get("Hit-Handler") != "1" { 5489 out, err := httputil.DumpResponse(res, true) 5490 if err != nil { 5491 t.Fatal(err) 5492 } 5493 t.Errorf("Response body wasn't from Handler. 
Got:\n%s\n", out) 5494 } 5495 } 5496 5497 // Issue 13290: send User-Agent in proxy CONNECT 5498 func TestTransportProxyConnectHeader(t *testing.T) { 5499 run(t, testTransportProxyConnectHeader, []testMode{http1Mode}) 5500 } 5501 func testTransportProxyConnectHeader(t *testing.T, mode testMode) { 5502 reqc := make(chan *Request, 1) 5503 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { 5504 if r.Method != "CONNECT" { 5505 t.Errorf("method = %q; want CONNECT", r.Method) 5506 } 5507 reqc <- r 5508 c, _, err := w.(Hijacker).Hijack() 5509 if err != nil { 5510 t.Errorf("Hijack: %v", err) 5511 return 5512 } 5513 c.Close() 5514 })).ts 5515 5516 c := ts.Client() 5517 c.Transport.(*Transport).Proxy = func(r *Request) (*url.URL, error) { 5518 return url.Parse(ts.URL) 5519 } 5520 c.Transport.(*Transport).ProxyConnectHeader = Header{ 5521 "User-Agent": {"foo"}, 5522 "Other": {"bar"}, 5523 } 5524 5525 res, err := c.Get("https://dummy.tld/") // https to force a CONNECT 5526 if err == nil { 5527 res.Body.Close() 5528 t.Errorf("unexpected success") 5529 } 5530 5531 r := <-reqc 5532 if got, want := r.Header.Get("User-Agent"), "foo"; got != want { 5533 t.Errorf("CONNECT request User-Agent = %q; want %q", got, want) 5534 } 5535 if got, want := r.Header.Get("Other"), "bar"; got != want { 5536 t.Errorf("CONNECT request Other = %q; want %q", got, want) 5537 } 5538 } 5539 5540 func TestTransportProxyGetConnectHeader(t *testing.T) { 5541 run(t, testTransportProxyGetConnectHeader, []testMode{http1Mode}) 5542 } 5543 func testTransportProxyGetConnectHeader(t *testing.T, mode testMode) { 5544 reqc := make(chan *Request, 1) 5545 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { 5546 if r.Method != "CONNECT" { 5547 t.Errorf("method = %q; want CONNECT", r.Method) 5548 } 5549 reqc <- r 5550 c, _, err := w.(Hijacker).Hijack() 5551 if err != nil { 5552 t.Errorf("Hijack: %v", err) 5553 return 5554 } 5555 c.Close() 5556 })).ts 5557 
5558 c := ts.Client() 5559 c.Transport.(*Transport).Proxy = func(r *Request) (*url.URL, error) { 5560 return url.Parse(ts.URL) 5561 } 5562 // These should be ignored: 5563 c.Transport.(*Transport).ProxyConnectHeader = Header{ 5564 "User-Agent": {"foo"}, 5565 "Other": {"bar"}, 5566 } 5567 c.Transport.(*Transport).GetProxyConnectHeader = func(ctx context.Context, proxyURL *url.URL, target string) (Header, error) { 5568 return Header{ 5569 "User-Agent": {"foo2"}, 5570 "Other": {"bar2"}, 5571 }, nil 5572 } 5573 5574 res, err := c.Get("https://dummy.tld/") // https to force a CONNECT 5575 if err == nil { 5576 res.Body.Close() 5577 t.Errorf("unexpected success") 5578 } 5579 5580 r := <-reqc 5581 if got, want := r.Header.Get("User-Agent"), "foo2"; got != want { 5582 t.Errorf("CONNECT request User-Agent = %q; want %q", got, want) 5583 } 5584 if got, want := r.Header.Get("Other"), "bar2"; got != want { 5585 t.Errorf("CONNECT request Other = %q; want %q", got, want) 5586 } 5587 } 5588 5589 var errFakeRoundTrip = errors.New("fake roundtrip") 5590 5591 type funcRoundTripper func() 5592 5593 func (fn funcRoundTripper) RoundTrip(*Request) (*Response, error) { 5594 fn() 5595 return nil, errFakeRoundTrip 5596 } 5597 5598 func wantBody(res *Response, err error, want string) error { 5599 if err != nil { 5600 return err 5601 } 5602 slurp, err := io.ReadAll(res.Body) 5603 if err != nil { 5604 return fmt.Errorf("error reading body: %v", err) 5605 } 5606 if string(slurp) != want { 5607 return fmt.Errorf("body = %q; want %q", slurp, want) 5608 } 5609 if err := res.Body.Close(); err != nil { 5610 return fmt.Errorf("body Close = %v", err) 5611 } 5612 return nil 5613 } 5614 5615 func newLocalListener(t *testing.T) net.Listener { 5616 ln, err := net.Listen("tcp", "127.0.0.1:0") 5617 if err != nil { 5618 ln, err = net.Listen("tcp6", "[::1]:0") 5619 } 5620 if err != nil { 5621 t.Fatal(err) 5622 } 5623 return ln 5624 } 5625 5626 type countCloseReader struct { 5627 n *int 5628 io.Reader 5629 } 
5630 5631 func (cr countCloseReader) Close() error { 5632 (*cr.n)++ 5633 return nil 5634 } 5635 5636 // rgz is a gzip quine that uncompresses to itself. 5637 var rgz = []byte{ 5638 0x1f, 0x8b, 0x08, 0x08, 0x00, 0x00, 0x00, 0x00, 5639 0x00, 0x00, 0x72, 0x65, 0x63, 0x75, 0x72, 0x73, 5640 0x69, 0x76, 0x65, 0x00, 0x92, 0xef, 0xe6, 0xe0, 5641 0x60, 0x00, 0x83, 0xa2, 0xd4, 0xe4, 0xd2, 0xa2, 5642 0xe2, 0xcc, 0xb2, 0x54, 0x06, 0x00, 0x00, 0x17, 5643 0x00, 0xe8, 0xff, 0x92, 0xef, 0xe6, 0xe0, 0x60, 5644 0x00, 0x83, 0xa2, 0xd4, 0xe4, 0xd2, 0xa2, 0xe2, 5645 0xcc, 0xb2, 0x54, 0x06, 0x00, 0x00, 0x17, 0x00, 5646 0xe8, 0xff, 0x42, 0x12, 0x46, 0x16, 0x06, 0x00, 5647 0x05, 0x00, 0xfa, 0xff, 0x42, 0x12, 0x46, 0x16, 5648 0x06, 0x00, 0x05, 0x00, 0xfa, 0xff, 0x00, 0x05, 5649 0x00, 0xfa, 0xff, 0x00, 0x14, 0x00, 0xeb, 0xff, 5650 0x42, 0x12, 0x46, 0x16, 0x06, 0x00, 0x05, 0x00, 5651 0xfa, 0xff, 0x00, 0x05, 0x00, 0xfa, 0xff, 0x00, 5652 0x14, 0x00, 0xeb, 0xff, 0x42, 0x88, 0x21, 0xc4, 5653 0x00, 0x00, 0x14, 0x00, 0xeb, 0xff, 0x42, 0x88, 5654 0x21, 0xc4, 0x00, 0x00, 0x14, 0x00, 0xeb, 0xff, 5655 0x42, 0x88, 0x21, 0xc4, 0x00, 0x00, 0x14, 0x00, 5656 0xeb, 0xff, 0x42, 0x88, 0x21, 0xc4, 0x00, 0x00, 5657 0x14, 0x00, 0xeb, 0xff, 0x42, 0x88, 0x21, 0xc4, 5658 0x00, 0x00, 0x00, 0x00, 0xff, 0xff, 0x00, 0x00, 5659 0x00, 0xff, 0xff, 0x00, 0x17, 0x00, 0xe8, 0xff, 5660 0x42, 0x88, 0x21, 0xc4, 0x00, 0x00, 0x00, 0x00, 5661 0xff, 0xff, 0x00, 0x00, 0x00, 0xff, 0xff, 0x00, 5662 0x17, 0x00, 0xe8, 0xff, 0x42, 0x12, 0x46, 0x16, 5663 0x06, 0x00, 0x00, 0x00, 0xff, 0xff, 0x01, 0x08, 5664 0x00, 0xf7, 0xff, 0x3d, 0xb1, 0x20, 0x85, 0xfa, 5665 0x00, 0x00, 0x00, 0x42, 0x12, 0x46, 0x16, 0x06, 5666 0x00, 0x00, 0x00, 0xff, 0xff, 0x01, 0x08, 0x00, 5667 0xf7, 0xff, 0x3d, 0xb1, 0x20, 0x85, 0xfa, 0x00, 5668 0x00, 0x00, 0x3d, 0xb1, 0x20, 0x85, 0xfa, 0x00, 5669 0x00, 0x00, 5670 } 5671 5672 // Ensure that a missing status doesn't make the server panic 5673 // See Issue https://golang.org/issues/21701 5674 func 
TestMissingStatusNoPanic(t *testing.T) { 5675 t.Parallel() 5676 5677 const want = "unknown status code" 5678 5679 ln := newLocalListener(t) 5680 addr := ln.Addr().String() 5681 done := make(chan bool) 5682 fullAddrURL := fmt.Sprintf("http://%s", addr) 5683 raw := "HTTP/1.1 400\r\n" + 5684 "Date: Wed, 30 Aug 2017 19:09:27 GMT\r\n" + 5685 "Content-Type: text/html; charset=utf-8\r\n" + 5686 "Content-Length: 10\r\n" + 5687 "Last-Modified: Wed, 30 Aug 2017 19:02:02 GMT\r\n" + 5688 "Vary: Accept-Encoding\r\n\r\n" + 5689 "Aloha Olaa" 5690 5691 go func() { 5692 defer close(done) 5693 5694 conn, _ := ln.Accept() 5695 if conn != nil { 5696 io.WriteString(conn, raw) 5697 io.ReadAll(conn) 5698 conn.Close() 5699 } 5700 }() 5701 5702 proxyURL, err := url.Parse(fullAddrURL) 5703 if err != nil { 5704 t.Fatalf("proxyURL: %v", err) 5705 } 5706 5707 tr := &Transport{Proxy: ProxyURL(proxyURL)} 5708 5709 req, _ := NewRequest("GET", "https://golang.org/", nil) 5710 res, err, panicked := doFetchCheckPanic(tr, req) 5711 if panicked { 5712 t.Error("panicked, expecting an error") 5713 } 5714 if res != nil && res.Body != nil { 5715 io.Copy(io.Discard, res.Body) 5716 res.Body.Close() 5717 } 5718 5719 if err == nil || !strings.Contains(err.Error(), want) { 5720 t.Errorf("got=%v want=%q", err, want) 5721 } 5722 5723 ln.Close() 5724 <-done 5725 } 5726 5727 func doFetchCheckPanic(tr *Transport, req *Request) (res *Response, err error, panicked bool) { 5728 defer func() { 5729 if r := recover(); r != nil { 5730 panicked = true 5731 } 5732 }() 5733 res, err = tr.RoundTrip(req) 5734 return 5735 } 5736 5737 // Issue 22330: do not allow the response body to be read when the status code 5738 // forbids a response body. 
5739 func TestNoBodyOnChunked304Response(t *testing.T) { 5740 run(t, testNoBodyOnChunked304Response, []testMode{http1Mode}) 5741 } 5742 func testNoBodyOnChunked304Response(t *testing.T, mode testMode) { 5743 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { 5744 conn, buf, _ := w.(Hijacker).Hijack() 5745 buf.Write([]byte("HTTP/1.1 304 NOT MODIFIED\r\nTransfer-Encoding: chunked\r\n\r\n0\r\n\r\n")) 5746 buf.Flush() 5747 conn.Close() 5748 })) 5749 5750 // Our test server above is sending back bogus data after the 5751 // response (the "0\r\n\r\n" part), which causes the Transport 5752 // code to log spam. Disable keep-alives so we never even try 5753 // to reuse the connection. 5754 cst.tr.DisableKeepAlives = true 5755 5756 res, err := cst.c.Get(cst.ts.URL) 5757 if err != nil { 5758 t.Fatal(err) 5759 } 5760 5761 if res.Body != NoBody { 5762 t.Errorf("Unexpected body on 304 response") 5763 } 5764 } 5765 5766 type funcWriter func([]byte) (int, error) 5767 5768 func (f funcWriter) Write(p []byte) (int, error) { return f(p) } 5769 5770 type doneContext struct { 5771 context.Context 5772 err error 5773 } 5774 5775 func (doneContext) Done() <-chan struct{} { 5776 c := make(chan struct{}) 5777 close(c) 5778 return c 5779 } 5780 5781 func (d doneContext) Err() error { return d.err } 5782 5783 // Issue 25852: Transport should check whether Context is done early. 5784 func TestTransportCheckContextDoneEarly(t *testing.T) { 5785 tr := &Transport{} 5786 req, _ := NewRequest("GET", "http://fake.example/", nil) 5787 wantErr := errors.New("some error") 5788 req = req.WithContext(doneContext{context.Background(), wantErr}) 5789 _, err := tr.RoundTrip(req) 5790 if err != wantErr { 5791 t.Errorf("error = %v; want %v", err, wantErr) 5792 } 5793 } 5794 5795 // Issue 23399: verify that if a client request times out, the Transport's 5796 // conn is closed so that it's not reused. 
//
// This is the test variant that times out before the server replies with
// any response headers.
func TestClientTimeoutKillsConn_BeforeHeaders(t *testing.T) {
	run(t, testClientTimeoutKillsConn_BeforeHeaders, []testMode{http1Mode})
}

// testClientTimeoutKillsConn_BeforeHeaders issues a GET with a very short
// Client.Timeout against a handler that never writes a response. It then
// checks from the server side that the connection reads EOF. If the handler
// is not reached within 10x the timeout, that attempt is abandoned and
// retried with the timeout doubled.
func testClientTimeoutKillsConn_BeforeHeaders(t *testing.T, mode testMode) {
	timeout := 1 * time.Millisecond
	for {
		// Each attempt gets its own channels and test server.
		inHandler := make(chan bool)
		cancelHandler := make(chan struct{})
		handlerDone := make(chan bool)
		cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
			// Block until the server cancels this request's context.
			<-r.Context().Done()

			// Either the test abandoned this attempt (cancelHandler closed),
			// or tell the test that the handler was reached.
			select {
			case <-cancelHandler:
				return
			case inHandler <- true:
			}
			defer func() { handlerDone <- true }()

			// Read from the conn until EOF to verify that it was correctly closed.
			conn, _, err := w.(Hijacker).Hijack()
			if err != nil {
				t.Error(err)
				return
			}
			n, err := conn.Read([]byte{0})
			if n != 0 || err != io.EOF {
				t.Errorf("unexpected Read result: %v, %v", n, err)
			}
			conn.Close()
		}))

		cst.c.Timeout = timeout

		_, err := cst.c.Get(cst.ts.URL)
		if err == nil {
			close(cancelHandler)
			t.Fatal("unexpected Get success")
		}

		tooSlow := time.NewTimer(timeout * 10)
		select {
		case <-tooSlow.C:
			// If we didn't get into the Handler, that probably means the builder was
			// just slow and the Get failed in that time but never made it to the
			// server. That's fine; we'll try again with a longer timeout.
			t.Logf("no handler seen in %v; retrying with longer timeout", timeout)
			close(cancelHandler)
			cst.close()
			timeout *= 2
			continue
		case <-inHandler:
			// The handler was reached; wait for its EOF check to finish.
			tooSlow.Stop()
			<-handlerDone
		}
		break
	}
}

// Issue 23399: verify that if a client request times out, the Transport's
// conn is closed so that it's not reused.
5861 // 5862 // This is the test variant that has the server send response headers 5863 // first, and time out during the write of the response body. 5864 func TestClientTimeoutKillsConn_AfterHeaders(t *testing.T) { 5865 run(t, testClientTimeoutKillsConn_AfterHeaders, []testMode{http1Mode}) 5866 } 5867 func testClientTimeoutKillsConn_AfterHeaders(t *testing.T, mode testMode) { 5868 inHandler := make(chan bool) 5869 cancelHandler := make(chan struct{}) 5870 handlerDone := make(chan bool) 5871 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { 5872 w.Header().Set("Content-Length", "100") 5873 w.(Flusher).Flush() 5874 5875 select { 5876 case <-cancelHandler: 5877 return 5878 case inHandler <- true: 5879 } 5880 defer func() { handlerDone <- true }() 5881 5882 conn, _, err := w.(Hijacker).Hijack() 5883 if err != nil { 5884 t.Error(err) 5885 return 5886 } 5887 conn.Write([]byte("foo")) 5888 5889 n, err := conn.Read([]byte{0}) 5890 // The error should be io.EOF or "read tcp 5891 // 127.0.0.1:35827->127.0.0.1:40290: read: connection 5892 // reset by peer" depending on timing. Really we just 5893 // care that it returns at all. But if it returns with 5894 // data, that's weird. 5895 if n != 0 || err == nil { 5896 t.Errorf("unexpected Read result: %v, %v", n, err) 5897 } 5898 conn.Close() 5899 })) 5900 5901 // Set Timeout to something very long but non-zero to exercise 5902 // the codepaths that check for it. But rather than wait for it to fire 5903 // (which would make the test slow), we send on the req.Cancel channel instead, 5904 // which happens to exercise the same code paths. 5905 cst.c.Timeout = 24 * time.Hour // just to be non-zero, not to hit it. 
5906 req, _ := NewRequest("GET", cst.ts.URL, nil) 5907 cancelReq := make(chan struct{}) 5908 req.Cancel = cancelReq 5909 5910 res, err := cst.c.Do(req) 5911 if err != nil { 5912 close(cancelHandler) 5913 t.Fatalf("Get error: %v", err) 5914 } 5915 5916 // Cancel the request while the handler is still blocked on sending to the 5917 // inHandler channel. Then read it until it fails, to verify that the 5918 // connection is broken before the handler itself closes it. 5919 close(cancelReq) 5920 got, err := io.ReadAll(res.Body) 5921 if err == nil { 5922 t.Errorf("unexpected success; read %q, nil", got) 5923 } 5924 5925 // Now unblock the handler and wait for it to complete. 5926 <-inHandler 5927 <-handlerDone 5928 } 5929 5930 func TestTransportResponseBodyWritableOnProtocolSwitch(t *testing.T) { 5931 run(t, testTransportResponseBodyWritableOnProtocolSwitch, []testMode{http1Mode}) 5932 } 5933 func testTransportResponseBodyWritableOnProtocolSwitch(t *testing.T, mode testMode) { 5934 done := make(chan struct{}) 5935 defer close(done) 5936 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { 5937 conn, _, err := w.(Hijacker).Hijack() 5938 if err != nil { 5939 t.Error(err) 5940 return 5941 } 5942 defer conn.Close() 5943 io.WriteString(conn, "HTTP/1.1 101 Switching Protocols Hi\r\nConnection: upgRADe\r\nUpgrade: foo\r\n\r\nSome buffered data\n") 5944 bs := bufio.NewScanner(conn) 5945 bs.Scan() 5946 fmt.Fprintf(conn, "%s\n", strings.ToUpper(bs.Text())) 5947 <-done 5948 })) 5949 5950 req, _ := NewRequest("GET", cst.ts.URL, nil) 5951 req.Header.Set("Upgrade", "foo") 5952 req.Header.Set("Connection", "upgrade") 5953 res, err := cst.c.Do(req) 5954 if err != nil { 5955 t.Fatal(err) 5956 } 5957 if res.StatusCode != 101 { 5958 t.Fatalf("expected 101 switching protocols; got %v, %v", res.Status, res.Header) 5959 } 5960 rwc, ok := res.Body.(io.ReadWriteCloser) 5961 if !ok { 5962 t.Fatalf("expected a ReadWriteCloser; got a %T", res.Body) 5963 } 5964 defer 
rwc.Close() 5965 bs := bufio.NewScanner(rwc) 5966 if !bs.Scan() { 5967 t.Fatalf("expected readable input") 5968 } 5969 if got, want := bs.Text(), "Some buffered data"; got != want { 5970 t.Errorf("read %q; want %q", got, want) 5971 } 5972 io.WriteString(rwc, "echo\n") 5973 if !bs.Scan() { 5974 t.Fatalf("expected another line") 5975 } 5976 if got, want := bs.Text(), "ECHO"; got != want { 5977 t.Errorf("read %q; want %q", got, want) 5978 } 5979 } 5980 5981 func TestTransportCONNECTBidi(t *testing.T) { run(t, testTransportCONNECTBidi, []testMode{http1Mode}) } 5982 func testTransportCONNECTBidi(t *testing.T, mode testMode) { 5983 const target = "backend:443" 5984 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { 5985 if r.Method != "CONNECT" { 5986 t.Errorf("unexpected method %q", r.Method) 5987 w.WriteHeader(500) 5988 return 5989 } 5990 if r.RequestURI != target { 5991 t.Errorf("unexpected CONNECT target %q", r.RequestURI) 5992 w.WriteHeader(500) 5993 return 5994 } 5995 nc, brw, err := w.(Hijacker).Hijack() 5996 if err != nil { 5997 t.Error(err) 5998 return 5999 } 6000 defer nc.Close() 6001 nc.Write([]byte("HTTP/1.1 200 OK\r\n\r\n")) 6002 // Switch to a little protocol that capitalize its input lines: 6003 for { 6004 line, err := brw.ReadString('\n') 6005 if err != nil { 6006 if err != io.EOF { 6007 t.Error(err) 6008 } 6009 return 6010 } 6011 io.WriteString(brw, strings.ToUpper(line)) 6012 brw.Flush() 6013 } 6014 })) 6015 pr, pw := io.Pipe() 6016 defer pw.Close() 6017 req, err := NewRequest("CONNECT", cst.ts.URL, pr) 6018 if err != nil { 6019 t.Fatal(err) 6020 } 6021 req.URL.Opaque = target 6022 res, err := cst.c.Do(req) 6023 if err != nil { 6024 t.Fatal(err) 6025 } 6026 defer res.Body.Close() 6027 if res.StatusCode != 200 { 6028 t.Fatalf("status code = %d; want 200", res.StatusCode) 6029 } 6030 br := bufio.NewReader(res.Body) 6031 for _, str := range []string{"foo", "bar", "baz"} { 6032 fmt.Fprintf(pw, "%s\n", str) 6033 got, err := 
br.ReadString('\n') 6034 if err != nil { 6035 t.Fatal(err) 6036 } 6037 got = strings.TrimSpace(got) 6038 want := strings.ToUpper(str) 6039 if got != want { 6040 t.Fatalf("got %q; want %q", got, want) 6041 } 6042 } 6043 } 6044 6045 func TestTransportRequestReplayable(t *testing.T) { 6046 someBody := io.NopCloser(strings.NewReader("")) 6047 tests := []struct { 6048 name string 6049 req *Request 6050 want bool 6051 }{ 6052 { 6053 name: "GET", 6054 req: &Request{Method: "GET"}, 6055 want: true, 6056 }, 6057 { 6058 name: "GET_http.NoBody", 6059 req: &Request{Method: "GET", Body: NoBody}, 6060 want: true, 6061 }, 6062 { 6063 name: "GET_body", 6064 req: &Request{Method: "GET", Body: someBody}, 6065 want: false, 6066 }, 6067 { 6068 name: "POST", 6069 req: &Request{Method: "POST"}, 6070 want: false, 6071 }, 6072 { 6073 name: "POST_idempotency-key", 6074 req: &Request{Method: "POST", Header: Header{"Idempotency-Key": {"x"}}}, 6075 want: true, 6076 }, 6077 { 6078 name: "POST_x-idempotency-key", 6079 req: &Request{Method: "POST", Header: Header{"X-Idempotency-Key": {"x"}}}, 6080 want: true, 6081 }, 6082 { 6083 name: "POST_body", 6084 req: &Request{Method: "POST", Header: Header{"Idempotency-Key": {"x"}}, Body: someBody}, 6085 want: false, 6086 }, 6087 } 6088 for _, tt := range tests { 6089 t.Run(tt.name, func(t *testing.T) { 6090 got := tt.req.ExportIsReplayable() 6091 if got != tt.want { 6092 t.Errorf("replyable = %v; want %v", got, tt.want) 6093 } 6094 }) 6095 } 6096 } 6097 6098 // testMockTCPConn is a mock TCP connection used to test that 6099 // ReadFrom is called when sending the request body. 
6100 type testMockTCPConn struct { 6101 *net.TCPConn 6102 6103 ReadFromCalled bool 6104 } 6105 6106 func (c *testMockTCPConn) ReadFrom(r io.Reader) (int64, error) { 6107 c.ReadFromCalled = true 6108 return c.TCPConn.ReadFrom(r) 6109 } 6110 6111 func TestTransportRequestWriteRoundTrip(t *testing.T) { run(t, testTransportRequestWriteRoundTrip) } 6112 func testTransportRequestWriteRoundTrip(t *testing.T, mode testMode) { 6113 nBytes := int64(1 << 10) 6114 newFileFunc := func() (r io.Reader, done func(), err error) { 6115 f, err := os.CreateTemp("", "net-http-newfilefunc") 6116 if err != nil { 6117 return nil, nil, err 6118 } 6119 6120 // Write some bytes to the file to enable reading. 6121 if _, err := io.CopyN(f, rand.Reader, nBytes); err != nil { 6122 return nil, nil, fmt.Errorf("failed to write data to file: %v", err) 6123 } 6124 if _, err := f.Seek(0, 0); err != nil { 6125 return nil, nil, fmt.Errorf("failed to seek to front: %v", err) 6126 } 6127 6128 done = func() { 6129 f.Close() 6130 os.Remove(f.Name()) 6131 } 6132 6133 return f, done, nil 6134 } 6135 6136 newBufferFunc := func() (io.Reader, func(), error) { 6137 return bytes.NewBuffer(make([]byte, nBytes)), func() {}, nil 6138 } 6139 6140 cases := []struct { 6141 name string 6142 readerFunc func() (io.Reader, func(), error) 6143 contentLength int64 6144 expectedReadFrom bool 6145 }{ 6146 { 6147 name: "file, length", 6148 readerFunc: newFileFunc, 6149 contentLength: nBytes, 6150 expectedReadFrom: true, 6151 }, 6152 { 6153 name: "file, no length", 6154 readerFunc: newFileFunc, 6155 }, 6156 { 6157 name: "file, negative length", 6158 readerFunc: newFileFunc, 6159 contentLength: -1, 6160 }, 6161 { 6162 name: "buffer", 6163 contentLength: nBytes, 6164 readerFunc: newBufferFunc, 6165 }, 6166 { 6167 name: "buffer, no length", 6168 readerFunc: newBufferFunc, 6169 }, 6170 { 6171 name: "buffer, length -1", 6172 contentLength: -1, 6173 readerFunc: newBufferFunc, 6174 }, 6175 } 6176 6177 for _, tc := range cases { 6178 
t.Run(tc.name, func(t *testing.T) { 6179 r, cleanup, err := tc.readerFunc() 6180 if err != nil { 6181 t.Fatal(err) 6182 } 6183 defer cleanup() 6184 6185 tConn := &testMockTCPConn{} 6186 trFunc := func(tr *Transport) { 6187 tr.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) { 6188 var d net.Dialer 6189 conn, err := d.DialContext(ctx, network, addr) 6190 if err != nil { 6191 return nil, err 6192 } 6193 6194 tcpConn, ok := conn.(*net.TCPConn) 6195 if !ok { 6196 return nil, fmt.Errorf("%s/%s does not provide a *net.TCPConn", network, addr) 6197 } 6198 6199 tConn.TCPConn = tcpConn 6200 return tConn, nil 6201 } 6202 } 6203 6204 cst := newClientServerTest( 6205 t, 6206 mode, 6207 HandlerFunc(func(w ResponseWriter, r *Request) { 6208 io.Copy(io.Discard, r.Body) 6209 r.Body.Close() 6210 w.WriteHeader(200) 6211 }), 6212 trFunc, 6213 ) 6214 6215 req, err := NewRequest("PUT", cst.ts.URL, r) 6216 if err != nil { 6217 t.Fatal(err) 6218 } 6219 req.ContentLength = tc.contentLength 6220 req.Header.Set("Content-Type", "application/octet-stream") 6221 resp, err := cst.c.Do(req) 6222 if err != nil { 6223 t.Fatal(err) 6224 } 6225 defer resp.Body.Close() 6226 if resp.StatusCode != 200 { 6227 t.Fatalf("status code = %d; want 200", resp.StatusCode) 6228 } 6229 6230 expectedReadFrom := tc.expectedReadFrom 6231 if mode != http1Mode { 6232 expectedReadFrom = false 6233 } 6234 if !tConn.ReadFromCalled && expectedReadFrom { 6235 t.Fatalf("did not call ReadFrom") 6236 } 6237 6238 if tConn.ReadFromCalled && !expectedReadFrom { 6239 t.Fatalf("ReadFrom was unexpectedly invoked") 6240 } 6241 }) 6242 } 6243 } 6244 6245 func TestTransportClone(t *testing.T) { 6246 tr := &Transport{ 6247 Proxy: func(*Request) (*url.URL, error) { panic("") }, 6248 OnProxyConnectResponse: func(ctx context.Context, proxyURL *url.URL, connectReq *Request, connectRes *Response) error { 6249 return nil 6250 }, 6251 DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) 
{ panic("") }, 6252 Dial: func(network, addr string) (net.Conn, error) { panic("") }, 6253 DialTLS: func(network, addr string) (net.Conn, error) { panic("") }, 6254 DialTLSContext: func(ctx context.Context, network, addr string) (net.Conn, error) { panic("") }, 6255 TLSClientConfig: new(tls.Config), 6256 TLSHandshakeTimeout: time.Second, 6257 DisableKeepAlives: true, 6258 DisableCompression: true, 6259 MaxIdleConns: 1, 6260 MaxIdleConnsPerHost: 1, 6261 MaxConnsPerHost: 1, 6262 IdleConnTimeout: time.Second, 6263 ResponseHeaderTimeout: time.Second, 6264 ExpectContinueTimeout: time.Second, 6265 ProxyConnectHeader: Header{}, 6266 GetProxyConnectHeader: func(context.Context, *url.URL, string) (Header, error) { return nil, nil }, 6267 MaxResponseHeaderBytes: 1, 6268 ForceAttemptHTTP2: true, 6269 TLSNextProto: map[string]func(authority string, c *tls.Conn) RoundTripper{ 6270 "foo": func(authority string, c *tls.Conn) RoundTripper { panic("") }, 6271 }, 6272 ReadBufferSize: 1, 6273 WriteBufferSize: 1, 6274 } 6275 tr2 := tr.Clone() 6276 rv := reflect.ValueOf(tr2).Elem() 6277 rt := rv.Type() 6278 for i := 0; i < rt.NumField(); i++ { 6279 sf := rt.Field(i) 6280 if !token.IsExported(sf.Name) { 6281 continue 6282 } 6283 if rv.Field(i).IsZero() { 6284 t.Errorf("cloned field t2.%s is zero", sf.Name) 6285 } 6286 } 6287 6288 if _, ok := tr2.TLSNextProto["foo"]; !ok { 6289 t.Errorf("cloned Transport lacked TLSNextProto 'foo' key") 6290 } 6291 6292 // But test that a nil TLSNextProto is kept nil: 6293 tr = new(Transport) 6294 tr2 = tr.Clone() 6295 if tr2.TLSNextProto != nil { 6296 t.Errorf("Transport.TLSNextProto unexpected non-nil") 6297 } 6298 } 6299 6300 func TestIs408(t *testing.T) { 6301 tests := []struct { 6302 in string 6303 want bool 6304 }{ 6305 {"HTTP/1.0 408", true}, 6306 {"HTTP/1.1 408", true}, 6307 {"HTTP/1.8 408", true}, 6308 {"HTTP/2.0 408", false}, // maybe h2c would do this? but false for now. 
6309 {"HTTP/1.1 408 ", true}, 6310 {"HTTP/1.1 40", false}, 6311 {"http/1.0 408", false}, 6312 {"HTTP/1-1 408", false}, 6313 } 6314 for _, tt := range tests { 6315 if got := Export_is408Message([]byte(tt.in)); got != tt.want { 6316 t.Errorf("is408Message(%q) = %v; want %v", tt.in, got, tt.want) 6317 } 6318 } 6319 } 6320 6321 func TestTransportIgnores408(t *testing.T) { 6322 run(t, testTransportIgnores408, []testMode{http1Mode}, testNotParallel) 6323 } 6324 func testTransportIgnores408(t *testing.T, mode testMode) { 6325 // Not parallel. Relies on mutating the log package's global Output. 6326 defer log.SetOutput(log.Writer()) 6327 6328 var logout strings.Builder 6329 log.SetOutput(&logout) 6330 6331 const target = "backend:443" 6332 6333 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { 6334 nc, _, err := w.(Hijacker).Hijack() 6335 if err != nil { 6336 t.Error(err) 6337 return 6338 } 6339 defer nc.Close() 6340 nc.Write([]byte("HTTP/1.1 200 OK\r\nContent-Length: 2\r\n\r\nok")) 6341 nc.Write([]byte("HTTP/1.1 408 bye\r\n")) // changing 408 to 409 makes test fail 6342 })) 6343 req, err := NewRequest("GET", cst.ts.URL, nil) 6344 if err != nil { 6345 t.Fatal(err) 6346 } 6347 res, err := cst.c.Do(req) 6348 if err != nil { 6349 t.Fatal(err) 6350 } 6351 slurp, err := io.ReadAll(res.Body) 6352 if err != nil { 6353 t.Fatal(err) 6354 } 6355 if err != nil { 6356 t.Fatal(err) 6357 } 6358 if string(slurp) != "ok" { 6359 t.Fatalf("got %q; want ok", slurp) 6360 } 6361 6362 waitCondition(t, 1*time.Millisecond, func(d time.Duration) bool { 6363 if n := cst.tr.IdleConnKeyCountForTesting(); n != 0 { 6364 if d > 0 { 6365 t.Logf("%v idle conns still present after %v", n, d) 6366 } 6367 return false 6368 } 6369 return true 6370 }) 6371 if got := logout.String(); got != "" { 6372 t.Fatalf("expected no log output; got: %s", got) 6373 } 6374 } 6375 6376 func TestInvalidHeaderResponse(t *testing.T) { 6377 run(t, testInvalidHeaderResponse, 
[]testMode{http1Mode}) 6378 } 6379 func testInvalidHeaderResponse(t *testing.T, mode testMode) { 6380 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { 6381 conn, buf, _ := w.(Hijacker).Hijack() 6382 buf.Write([]byte("HTTP/1.1 200 OK\r\n" + 6383 "Date: Wed, 30 Aug 2017 19:09:27 GMT\r\n" + 6384 "Content-Type: text/html; charset=utf-8\r\n" + 6385 "Content-Length: 0\r\n" + 6386 "Foo : bar\r\n\r\n")) 6387 buf.Flush() 6388 conn.Close() 6389 })) 6390 res, err := cst.c.Get(cst.ts.URL) 6391 if err != nil { 6392 t.Fatal(err) 6393 } 6394 defer res.Body.Close() 6395 if v := res.Header.Get("Foo"); v != "" { 6396 t.Errorf(`unexpected "Foo" header: %q`, v) 6397 } 6398 if v := res.Header.Get("Foo "); v != "bar" { 6399 t.Errorf(`bad "Foo " header value: %q, want %q`, v, "bar") 6400 } 6401 } 6402 6403 type bodyCloser bool 6404 6405 func (bc *bodyCloser) Close() error { 6406 *bc = true 6407 return nil 6408 } 6409 func (bc *bodyCloser) Read(b []byte) (n int, err error) { 6410 return 0, io.EOF 6411 } 6412 6413 // Issue 35015: ensure that Transport closes the body on any error 6414 // with an invalid request, as promised by Client.Do docs. 
6415 func TestTransportClosesBodyOnInvalidRequests(t *testing.T) { 6416 run(t, testTransportClosesBodyOnInvalidRequests) 6417 } 6418 func testTransportClosesBodyOnInvalidRequests(t *testing.T, mode testMode) { 6419 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { 6420 t.Errorf("Should not have been invoked") 6421 })).ts 6422 6423 u, _ := url.Parse(cst.URL) 6424 6425 tests := []struct { 6426 name string 6427 req *Request 6428 wantErr string 6429 }{ 6430 { 6431 name: "invalid method", 6432 req: &Request{ 6433 Method: " ", 6434 URL: u, 6435 }, 6436 wantErr: `invalid method " "`, 6437 }, 6438 { 6439 name: "nil URL", 6440 req: &Request{ 6441 Method: "GET", 6442 }, 6443 wantErr: `nil Request.URL`, 6444 }, 6445 { 6446 name: "invalid header key", 6447 req: &Request{ 6448 Method: "GET", 6449 Header: Header{"💡": {"emoji"}}, 6450 URL: u, 6451 }, 6452 wantErr: `invalid header field name "💡"`, 6453 }, 6454 { 6455 name: "invalid header value", 6456 req: &Request{ 6457 Method: "POST", 6458 Header: Header{"key": {"\x19"}}, 6459 URL: u, 6460 }, 6461 wantErr: `invalid header field value for "key"`, 6462 }, 6463 { 6464 name: "non HTTP(s) scheme", 6465 req: &Request{ 6466 Method: "POST", 6467 URL: &url.URL{Scheme: "faux"}, 6468 }, 6469 wantErr: `unsupported protocol scheme "faux"`, 6470 }, 6471 { 6472 name: "no Host in URL", 6473 req: &Request{ 6474 Method: "POST", 6475 URL: &url.URL{Scheme: "http"}, 6476 }, 6477 wantErr: `no Host in request URL`, 6478 }, 6479 } 6480 6481 for _, tt := range tests { 6482 t.Run(tt.name, func(t *testing.T) { 6483 var bc bodyCloser 6484 req := tt.req 6485 req.Body = &bc 6486 _, err := cst.Client().Do(tt.req) 6487 if err == nil { 6488 t.Fatal("Expected an error") 6489 } 6490 if !bc { 6491 t.Fatal("Expected body to have been closed") 6492 } 6493 if g, w := err.Error(), tt.wantErr; !strings.HasSuffix(g, w) { 6494 t.Fatalf("Error mismatch: %q does not end with %q", g, w) 6495 } 6496 }) 6497 } 6498 } 6499 6500 // 
breakableConn is a net.Conn wrapper with a Write method 6501 // that will fail when its brokenState is true. 6502 type breakableConn struct { 6503 net.Conn 6504 *brokenState 6505 } 6506 6507 type brokenState struct { 6508 sync.Mutex 6509 broken bool 6510 } 6511 6512 func (w *breakableConn) Write(b []byte) (n int, err error) { 6513 w.Lock() 6514 defer w.Unlock() 6515 if w.broken { 6516 return 0, errors.New("some write error") 6517 } 6518 return w.Conn.Write(b) 6519 } 6520 6521 // Issue 34978: don't cache a broken HTTP/2 connection 6522 func TestDontCacheBrokenHTTP2Conn(t *testing.T) { 6523 run(t, testDontCacheBrokenHTTP2Conn, []testMode{http2Mode}) 6524 } 6525 func testDontCacheBrokenHTTP2Conn(t *testing.T, mode testMode) { 6526 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {}), optQuietLog) 6527 6528 var brokenState brokenState 6529 6530 const numReqs = 5 6531 var numDials, gotConns uint32 // atomic 6532 6533 cst.tr.Dial = func(netw, addr string) (net.Conn, error) { 6534 atomic.AddUint32(&numDials, 1) 6535 c, err := net.Dial(netw, addr) 6536 if err != nil { 6537 t.Errorf("unexpected Dial error: %v", err) 6538 return nil, err 6539 } 6540 return &breakableConn{c, &brokenState}, err 6541 } 6542 6543 for i := 1; i <= numReqs; i++ { 6544 brokenState.Lock() 6545 brokenState.broken = false 6546 brokenState.Unlock() 6547 6548 // doBreak controls whether we break the TCP connection after the TLS 6549 // handshake (before the HTTP/2 handshake). We test a few failures 6550 // in a row followed by a final success. 
6551 doBreak := i != numReqs 6552 6553 ctx := httptrace.WithClientTrace(context.Background(), &httptrace.ClientTrace{ 6554 GotConn: func(info httptrace.GotConnInfo) { 6555 t.Logf("got conn: %v, reused=%v, wasIdle=%v, idleTime=%v", info.Conn.LocalAddr(), info.Reused, info.WasIdle, info.IdleTime) 6556 atomic.AddUint32(&gotConns, 1) 6557 }, 6558 TLSHandshakeDone: func(cfg tls.ConnectionState, err error) { 6559 brokenState.Lock() 6560 defer brokenState.Unlock() 6561 if doBreak { 6562 brokenState.broken = true 6563 } 6564 }, 6565 }) 6566 req, err := NewRequestWithContext(ctx, "GET", cst.ts.URL, nil) 6567 if err != nil { 6568 t.Fatal(err) 6569 } 6570 _, err = cst.c.Do(req) 6571 if doBreak != (err != nil) { 6572 t.Errorf("for iteration %d, doBreak=%v; unexpected error %v", i, doBreak, err) 6573 } 6574 } 6575 if got, want := atomic.LoadUint32(&gotConns), 1; int(got) != want { 6576 t.Errorf("GotConn calls = %v; want %v", got, want) 6577 } 6578 if got, want := atomic.LoadUint32(&numDials), numReqs; int(got) != want { 6579 t.Errorf("Dials = %v; want %v", got, want) 6580 } 6581 } 6582 6583 // Issue 34941 6584 // When the client has too many concurrent requests on a single connection, 6585 // http.http2noCachedConnError is reported on multiple requests. There should 6586 // only be one decrement regardless of the number of failures. 
6587 func TestTransportDecrementConnWhenIdleConnRemoved(t *testing.T) { 6588 run(t, testTransportDecrementConnWhenIdleConnRemoved, []testMode{http2Mode}) 6589 } 6590 func testTransportDecrementConnWhenIdleConnRemoved(t *testing.T, mode testMode) { 6591 CondSkipHTTP2(t) 6592 6593 h := HandlerFunc(func(w ResponseWriter, r *Request) { 6594 _, err := w.Write([]byte("foo")) 6595 if err != nil { 6596 t.Fatalf("Write: %v", err) 6597 } 6598 }) 6599 6600 ts := newClientServerTest(t, mode, h).ts 6601 6602 c := ts.Client() 6603 tr := c.Transport.(*Transport) 6604 tr.MaxConnsPerHost = 1 6605 6606 errCh := make(chan error, 300) 6607 doReq := func() { 6608 resp, err := c.Get(ts.URL) 6609 if err != nil { 6610 errCh <- fmt.Errorf("request failed: %v", err) 6611 return 6612 } 6613 defer resp.Body.Close() 6614 _, err = io.ReadAll(resp.Body) 6615 if err != nil { 6616 errCh <- fmt.Errorf("read body failed: %v", err) 6617 } 6618 } 6619 6620 var wg sync.WaitGroup 6621 for i := 0; i < 300; i++ { 6622 wg.Add(1) 6623 go func() { 6624 defer wg.Done() 6625 doReq() 6626 }() 6627 } 6628 wg.Wait() 6629 close(errCh) 6630 6631 for err := range errCh { 6632 t.Errorf("error occurred: %v", err) 6633 } 6634 } 6635 6636 // Issue 36820 6637 // Test that we use the older backward compatible cancellation protocol 6638 // when a RoundTripper is registered via RegisterProtocol. 
6639 func TestAltProtoCancellation(t *testing.T) { 6640 defer afterTest(t) 6641 tr := &Transport{} 6642 c := &Client{ 6643 Transport: tr, 6644 Timeout: time.Millisecond, 6645 } 6646 tr.RegisterProtocol("cancel", cancelProto{}) 6647 _, err := c.Get("cancel://bar.com/path") 6648 if err == nil { 6649 t.Error("request unexpectedly succeeded") 6650 } else if !strings.Contains(err.Error(), errCancelProto.Error()) { 6651 t.Errorf("got error %q, does not contain expected string %q", err, errCancelProto) 6652 } 6653 } 6654 6655 var errCancelProto = errors.New("canceled as expected") 6656 6657 type cancelProto struct{} 6658 6659 func (cancelProto) RoundTrip(req *Request) (*Response, error) { 6660 <-req.Cancel 6661 return nil, errCancelProto 6662 } 6663 6664 type roundTripFunc func(r *Request) (*Response, error) 6665 6666 func (f roundTripFunc) RoundTrip(r *Request) (*Response, error) { return f(r) } 6667 6668 // Issue 32441: body is not reset after ErrSkipAltProtocol 6669 func TestIssue32441(t *testing.T) { run(t, testIssue32441, []testMode{http1Mode}) } 6670 func testIssue32441(t *testing.T, mode testMode) { 6671 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { 6672 if n, _ := io.Copy(io.Discard, r.Body); n == 0 { 6673 t.Error("body length is zero") 6674 } 6675 })).ts 6676 c := ts.Client() 6677 c.Transport.(*Transport).RegisterProtocol("http", roundTripFunc(func(r *Request) (*Response, error) { 6678 // Draining body to trigger failure condition on actual request to server. 6679 if n, _ := io.Copy(io.Discard, r.Body); n == 0 { 6680 t.Error("body length is zero during round trip") 6681 } 6682 return nil, ErrSkipAltProtocol 6683 })) 6684 if _, err := c.Post(ts.URL, "application/octet-stream", bytes.NewBufferString("data")); err != nil { 6685 t.Error(err) 6686 } 6687 } 6688 6689 // Issue 39017. Ensure that HTTP/1 transports reject Content-Length headers 6690 // that contain a sign (eg. "+3"), per RFC 2616, Section 14.13. 
6691 func TestTransportRejectsSignInContentLength(t *testing.T) { 6692 run(t, testTransportRejectsSignInContentLength, []testMode{http1Mode}) 6693 } 6694 func testTransportRejectsSignInContentLength(t *testing.T, mode testMode) { 6695 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { 6696 w.Header().Set("Content-Length", "+3") 6697 w.Write([]byte("abc")) 6698 })).ts 6699 6700 c := cst.Client() 6701 res, err := c.Get(cst.URL) 6702 if err == nil || res != nil { 6703 t.Fatal("Expected a non-nil error and a nil http.Response") 6704 } 6705 if got, want := err.Error(), `bad Content-Length "+3"`; !strings.Contains(got, want) { 6706 t.Fatalf("Error mismatch\nGot: %q\nWanted substring: %q", got, want) 6707 } 6708 } 6709 6710 // dumpConn is a net.Conn which writes to Writer and reads from Reader 6711 type dumpConn struct { 6712 io.Writer 6713 io.Reader 6714 } 6715 6716 func (c *dumpConn) Close() error { return nil } 6717 func (c *dumpConn) LocalAddr() net.Addr { return nil } 6718 func (c *dumpConn) RemoteAddr() net.Addr { return nil } 6719 func (c *dumpConn) SetDeadline(t time.Time) error { return nil } 6720 func (c *dumpConn) SetReadDeadline(t time.Time) error { return nil } 6721 func (c *dumpConn) SetWriteDeadline(t time.Time) error { return nil } 6722 6723 // delegateReader is a reader that delegates to another reader, 6724 // once it arrives on a channel. 
6725 type delegateReader struct { 6726 c chan io.Reader 6727 r io.Reader // nil until received from c 6728 } 6729 6730 func (r *delegateReader) Read(p []byte) (int, error) { 6731 if r.r == nil { 6732 var ok bool 6733 if r.r, ok = <-r.c; !ok { 6734 return 0, errors.New("delegate closed") 6735 } 6736 } 6737 return r.r.Read(p) 6738 } 6739 6740 func testTransportRace(req *Request) { 6741 save := req.Body 6742 pr, pw := io.Pipe() 6743 defer pr.Close() 6744 defer pw.Close() 6745 dr := &delegateReader{c: make(chan io.Reader)} 6746 6747 t := &Transport{ 6748 Dial: func(net, addr string) (net.Conn, error) { 6749 return &dumpConn{pw, dr}, nil 6750 }, 6751 } 6752 defer t.CloseIdleConnections() 6753 6754 quitReadCh := make(chan struct{}) 6755 // Wait for the request before replying with a dummy response: 6756 go func() { 6757 defer close(quitReadCh) 6758 6759 req, err := ReadRequest(bufio.NewReader(pr)) 6760 if err == nil { 6761 // Ensure all the body is read; otherwise 6762 // we'll get a partial dump. 6763 io.Copy(io.Discard, req.Body) 6764 req.Body.Close() 6765 } 6766 select { 6767 case dr.c <- strings.NewReader("HTTP/1.1 204 No Content\r\nConnection: close\r\n\r\n"): 6768 case quitReadCh <- struct{}{}: 6769 // Ensure delegate is closed so Read doesn't block forever. 6770 close(dr.c) 6771 } 6772 }() 6773 6774 t.RoundTrip(req) 6775 6776 // Ensure the reader returns before we reset req.Body to prevent 6777 // a data race on req.Body. 6778 pw.Close() 6779 <-quitReadCh 6780 6781 req.Body = save 6782 } 6783 6784 // Issue 37669 6785 // Test that a cancellation doesn't result in a data race due to the writeLoop 6786 // goroutine being left running, if the caller mutates the processed Request 6787 // upon completion. 
6788 func TestErrorWriteLoopRace(t *testing.T) { 6789 if testing.Short() { 6790 return 6791 } 6792 t.Parallel() 6793 for i := 0; i < 1000; i++ { 6794 delay := time.Duration(mrand.Intn(5)) * time.Millisecond 6795 ctx, cancel := context.WithTimeout(context.Background(), delay) 6796 defer cancel() 6797 6798 r := bytes.NewBuffer(make([]byte, 10000)) 6799 req, err := NewRequestWithContext(ctx, MethodPost, "http://example.com", r) 6800 if err != nil { 6801 t.Fatal(err) 6802 } 6803 6804 testTransportRace(req) 6805 } 6806 } 6807 6808 // Issue 41600 6809 // Test that a new request which uses the connection of an active request 6810 // cannot cause it to be canceled as well. 6811 func TestCancelRequestWhenSharingConnection(t *testing.T) { 6812 run(t, testCancelRequestWhenSharingConnection, []testMode{http1Mode}) 6813 } 6814 func testCancelRequestWhenSharingConnection(t *testing.T, mode testMode) { 6815 reqc := make(chan chan struct{}, 2) 6816 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, req *Request) { 6817 ch := make(chan struct{}, 1) 6818 reqc <- ch 6819 <-ch 6820 w.Header().Add("Content-Length", "0") 6821 })).ts 6822 6823 client := ts.Client() 6824 transport := client.Transport.(*Transport) 6825 transport.MaxIdleConns = 1 6826 transport.MaxConnsPerHost = 1 6827 6828 var wg sync.WaitGroup 6829 6830 wg.Add(1) 6831 putidlec := make(chan chan struct{}, 1) 6832 reqerrc := make(chan error, 1) 6833 go func() { 6834 defer wg.Done() 6835 ctx := httptrace.WithClientTrace(context.Background(), &httptrace.ClientTrace{ 6836 PutIdleConn: func(error) { 6837 // Signal that the idle conn has been returned to the pool, 6838 // and wait for the order to proceed. 
6839 ch := make(chan struct{}) 6840 putidlec <- ch 6841 close(putidlec) // panic if PutIdleConn runs twice for some reason 6842 <-ch 6843 }, 6844 }) 6845 req, _ := NewRequestWithContext(ctx, "GET", ts.URL, nil) 6846 res, err := client.Do(req) 6847 if err != nil { 6848 reqerrc <- err 6849 } else { 6850 res.Body.Close() 6851 } 6852 }() 6853 6854 // Wait for the first request to receive a response and return the 6855 // connection to the idle pool. 6856 select { 6857 case err := <-reqerrc: 6858 t.Fatalf("request 1: got err %v, want nil", err) 6859 case r1c := <-reqc: 6860 close(r1c) 6861 } 6862 var idlec chan struct{} 6863 select { 6864 case err := <-reqerrc: 6865 t.Fatalf("request 1: got err %v, want nil", err) 6866 case idlec = <-putidlec: 6867 } 6868 6869 wg.Add(1) 6870 cancelctx, cancel := context.WithCancel(context.Background()) 6871 go func() { 6872 defer wg.Done() 6873 req, _ := NewRequestWithContext(cancelctx, "GET", ts.URL, nil) 6874 res, err := client.Do(req) 6875 if err == nil { 6876 res.Body.Close() 6877 } 6878 if !errors.Is(err, context.Canceled) { 6879 t.Errorf("request 2: got err %v, want Canceled", err) 6880 } 6881 6882 // Unblock the first request. 6883 close(idlec) 6884 }() 6885 6886 // Wait for the second request to arrive at the server, and then cancel 6887 // the request context. 
6888 r2c := <-reqc 6889 cancel() 6890 6891 <-idlec 6892 6893 close(r2c) 6894 wg.Wait() 6895 } 6896 6897 func TestHandlerAbortRacesBodyRead(t *testing.T) { run(t, testHandlerAbortRacesBodyRead) } 6898 func testHandlerAbortRacesBodyRead(t *testing.T, mode testMode) { 6899 ts := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) { 6900 go io.Copy(io.Discard, req.Body) 6901 panic(ErrAbortHandler) 6902 })).ts 6903 6904 var wg sync.WaitGroup 6905 for i := 0; i < 2; i++ { 6906 wg.Add(1) 6907 go func() { 6908 defer wg.Done() 6909 for j := 0; j < 10; j++ { 6910 const reqLen = 6 * 1024 * 1024 6911 req, _ := NewRequest("POST", ts.URL, &io.LimitedReader{R: neverEnding('x'), N: reqLen}) 6912 req.ContentLength = reqLen 6913 resp, _ := ts.Client().Transport.RoundTrip(req) 6914 if resp != nil { 6915 resp.Body.Close() 6916 } 6917 } 6918 }() 6919 } 6920 wg.Wait() 6921 } 6922 6923 func TestRequestSanitization(t *testing.T) { run(t, testRequestSanitization) } 6924 func testRequestSanitization(t *testing.T, mode testMode) { 6925 if mode == http2Mode { 6926 // Remove this after updating x/net. 6927 t.Skip("https://go.dev/issue/60374 test fails when run with HTTP/2") 6928 } 6929 ts := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) { 6930 if h, ok := req.Header["X-Evil"]; ok { 6931 t.Errorf("request has X-Evil header: %q", h) 6932 } 6933 })).ts 6934 req, _ := NewRequest("GET", ts.URL, nil) 6935 req.Host = "go.dev\r\nX-Evil:evil" 6936 resp, _ := ts.Client().Do(req) 6937 if resp != nil { 6938 resp.Body.Close() 6939 } 6940 } 6941 6942 func TestProxyAuthHeader(t *testing.T) { 6943 // Not parallel: Sets an environment variable. 6944 run(t, testProxyAuthHeader, []testMode{http1Mode}, testNotParallel) 6945 } 6946 func testProxyAuthHeader(t *testing.T, mode testMode) { 6947 const username = "u" 6948 const password = "@/?!" 
6949 cst := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) { 6950 // Copy the Proxy-Authorization header to a new Request, 6951 // since Request.BasicAuth only parses the Authorization header. 6952 var r2 Request 6953 r2.Header = Header{ 6954 "Authorization": req.Header["Proxy-Authorization"], 6955 } 6956 gotuser, gotpass, ok := r2.BasicAuth() 6957 if !ok || gotuser != username || gotpass != password { 6958 t.Errorf("req.BasicAuth() = %q, %q, %v; want %q, %q, true", gotuser, gotpass, ok, username, password) 6959 } 6960 })) 6961 u, err := url.Parse(cst.ts.URL) 6962 if err != nil { 6963 t.Fatal(err) 6964 } 6965 u.User = url.UserPassword(username, password) 6966 t.Setenv("HTTP_PROXY", u.String()) 6967 cst.tr.Proxy = ProxyURL(u) 6968 resp, err := cst.c.Get("http://_/") 6969 if err != nil { 6970 t.Fatal(err) 6971 } 6972 resp.Body.Close() 6973 } 6974 6975 // Issue 61708 6976 func TestTransportReqCancelerCleanupOnRequestBodyWriteError(t *testing.T) { 6977 ln := newLocalListener(t) 6978 addr := ln.Addr().String() 6979 6980 done := make(chan struct{}) 6981 go func() { 6982 conn, err := ln.Accept() 6983 if err != nil { 6984 t.Errorf("ln.Accept: %v", err) 6985 return 6986 } 6987 // Start reading request before sending response to avoid 6988 // "Unsolicited response received on idle HTTP channel" RoundTrip error. 6989 if _, err := io.ReadFull(conn, make([]byte, 1)); err != nil { 6990 t.Errorf("conn.Read: %v", err) 6991 return 6992 } 6993 io.WriteString(conn, "HTTP/1.1 200\r\nContent-Length: 3\r\n\r\nfoo") 6994 <-done 6995 conn.Close() 6996 }() 6997 6998 didRead := make(chan bool) 6999 SetReadLoopBeforeNextReadHook(func() { didRead <- true }) 7000 defer SetReadLoopBeforeNextReadHook(nil) 7001 7002 tr := &Transport{} 7003 7004 // Send a request with a body guaranteed to fail on write. 
7005 req, err := NewRequest("POST", "http://"+addr, io.LimitReader(neverEnding('x'), 1<<30)) 7006 if err != nil { 7007 t.Fatalf("NewRequest: %v", err) 7008 } 7009 7010 resp, err := tr.RoundTrip(req) 7011 if err != nil { 7012 t.Fatalf("tr.RoundTrip: %v", err) 7013 } 7014 7015 close(done) 7016 7017 // Before closing response body wait for readLoopDone goroutine 7018 // to complete due to closed connection by writeLoop. 7019 <-didRead 7020 7021 resp.Body.Close() 7022 7023 // Verify no outstanding requests after readLoop/writeLoop 7024 // goroutines shut down. 7025 waitCondition(t, 10*time.Millisecond, func(d time.Duration) bool { 7026 n := tr.NumPendingRequestsForTesting() 7027 if n > 0 { 7028 if d > 0 { 7029 t.Logf("pending requests = %d after %v (want 0)", n, d) 7030 } 7031 return false 7032 } 7033 return true 7034 }) 7035 } 7036 7037 func TestValidateClientRequestTrailers(t *testing.T) { 7038 run(t, testValidateClientRequestTrailers) 7039 } 7040 7041 func testValidateClientRequestTrailers(t *testing.T, mode testMode) { 7042 cst := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) { 7043 rw.Write([]byte("Hello")) 7044 })).ts 7045 7046 cases := []struct { 7047 trailer Header 7048 wantErr string 7049 }{ 7050 {Header{"Trx": {"x\r\nX-Another-One"}}, `invalid trailer field value for "Trx"`}, 7051 {Header{"\r\nTrx": {"X-Another-One"}}, `invalid trailer field name "\r\nTrx"`}, 7052 } 7053 7054 for i, tt := range cases { 7055 testName := fmt.Sprintf("%s%d", mode, i) 7056 t.Run(testName, func(t *testing.T) { 7057 req, err := NewRequest("GET", cst.URL, nil) 7058 if err != nil { 7059 t.Fatal(err) 7060 } 7061 req.Trailer = tt.trailer 7062 res, err := cst.Client().Do(req) 7063 if err == nil { 7064 t.Fatal("Expected an error") 7065 } 7066 if g, w := err.Error(), tt.wantErr; !strings.Contains(g, w) { 7067 t.Fatalf("Mismatched error\n\t%q\ndoes not contain\n\t%q", g, w) 7068 } 7069 if res != nil { 7070 t.Fatal("Unexpected non-nil response") 7071 } 7072 
}) 7073 } 7074 }