github.com/ice-blockchain/go/src@v0.0.0-20240403114104-1564d284e521/net/http/clientserver_test.go (about) 1 // Copyright 2015 The Go Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 // Tests that use both the client & server, in both HTTP/1 and HTTP/2 mode. 6 7 package http_test 8 9 import ( 10 "bytes" 11 "compress/gzip" 12 "context" 13 "crypto/rand" 14 "crypto/sha1" 15 "crypto/tls" 16 "fmt" 17 "hash" 18 "io" 19 "log" 20 "net" 21 . "net/http" 22 "net/http/httptest" 23 "net/http/httptrace" 24 "net/http/httputil" 25 "net/textproto" 26 "net/url" 27 "os" 28 "reflect" 29 "runtime" 30 "sort" 31 "strings" 32 "sync" 33 "sync/atomic" 34 "testing" 35 "time" 36 ) 37 38 type testMode string 39 40 const ( 41 http1Mode = testMode("h1") // HTTP/1.1 42 https1Mode = testMode("https1") // HTTPS/1.1 43 http2Mode = testMode("h2") // HTTP/2 44 ) 45 46 type testNotParallelOpt struct{} 47 48 var ( 49 testNotParallel = testNotParallelOpt{} 50 ) 51 52 type TBRun[T any] interface { 53 testing.TB 54 Run(string, func(T)) bool 55 } 56 57 // run runs a client/server test in a variety of test configurations. 58 // 59 // Tests execute in HTTP/1.1 and HTTP/2 modes by default. 60 // To run in a different set of configurations, pass a []testMode option. 61 // 62 // Tests call t.Parallel() by default. 63 // To disable parallel execution, pass the testNotParallel option. 64 func run[T TBRun[T]](t T, f func(t T, mode testMode), opts ...any) { 65 t.Helper() 66 modes := []testMode{http1Mode, http2Mode} 67 parallel := true 68 for _, opt := range opts { 69 switch opt := opt.(type) { 70 case []testMode: 71 modes = opt 72 case testNotParallelOpt: 73 parallel = false 74 default: 75 t.Fatalf("unknown option type %T", opt) 76 } 77 } 78 if t, ok := any(t).(*testing.T); ok && parallel { 79 setParallel(t) 80 } 81 for _, mode := range modes { 82 t.Run(string(mode), func(t T) { 83 t.Helper() 84 if t, ok := any(t).(*testing.T); ok && parallel { 85 setParallel(t) 86 } 87 t.Cleanup(func() { 88 afterTest(t) 89 }) 90 f(t, mode) 91 }) 92 } 93 } 94 95 type clientServerTest struct { 96 t testing.TB 97 h2 bool 98 h Handler 99 ts *httptest.Server 100 tr *Transport 101 c *Client 102 } 103 104 func (t *clientServerTest) close() { 105 t.tr.CloseIdleConnections() 106 t.ts.Close() 107 } 108 109 func (t *clientServerTest) getURL(u string) string { 110 res, err := t.c.Get(u) 111 if err != nil { 112 t.t.Fatal(err) 113 } 114 defer res.Body.Close() 115 slurp, err := io.ReadAll(res.Body) 116 if err != nil { 117 t.t.Fatal(err) 118 } 119 return string(slurp) 120 } 121 122 func (t *clientServerTest) scheme() string { 123 if t.h2 { 124 return "https" 125 } 126 return "http" 127 } 128 129 var optQuietLog = func(ts *httptest.Server) { 130 ts.Config.ErrorLog = quietLog 131 } 132 133 func optWithServerLog(lg *log.Logger) func(*httptest.Server) { 134 return func(ts *httptest.Server) { 135 ts.Config.ErrorLog = lg 136 } 137 } 138 139 // newClientServerTest creates and starts an httptest.Server. 140 // 141 // The mode parameter selects the implementation to test: 142 // HTTP/1, HTTP/2, etc. Tests using newClientServerTest should use 143 // the 'run' function, which will start a subtests for each tested mode. 144 // 145 // The vararg opts parameter can include functions to configure the 146 // test server or transport. 147 // 148 // func(*httptest.Server) // run before starting the server 149 // func(*http.Transport) 150 func newClientServerTest(t testing.TB, mode testMode, h Handler, opts ...any) *clientServerTest { 151 if mode == http2Mode { 152 CondSkipHTTP2(t) 153 } 154 cst := &clientServerTest{ 155 t: t, 156 h2: mode == http2Mode, 157 h: h, 158 } 159 cst.ts = httptest.NewUnstartedServer(h) 160 161 var transportFuncs []func(*Transport) 162 for _, opt := range opts { 163 switch opt := opt.(type) { 164 case func(*Transport): 165 transportFuncs = append(transportFuncs, opt) 166 case func(*httptest.Server): 167 opt(cst.ts) 168 default: 169 t.Fatalf("unhandled option type %T", opt) 170 } 171 } 172 173 if cst.ts.Config.ErrorLog == nil { 174 cst.ts.Config.ErrorLog = log.New(testLogWriter{t}, "", 0) 175 } 176 177 switch mode { 178 case http1Mode: 179 cst.ts.Start() 180 case https1Mode: 181 cst.ts.StartTLS() 182 case http2Mode: 183 ExportHttp2ConfigureServer(cst.ts.Config, nil) 184 cst.ts.TLS = cst.ts.Config.TLSConfig 185 cst.ts.StartTLS() 186 default: 187 t.Fatalf("unknown test mode %v", mode) 188 } 189 cst.c = cst.ts.Client() 190 cst.tr = cst.c.Transport.(*Transport) 191 if mode == http2Mode { 192 if err := ExportHttp2ConfigureTransport(cst.tr); err != nil { 193 t.Fatal(err) 194 } 195 } 196 for _, f := range transportFuncs { 197 f(cst.tr) 198 } 199 t.Cleanup(func() { 200 cst.close() 201 }) 202 return cst 203 } 204 205 type testLogWriter struct { 206 t testing.TB 207 } 208 209 func (w testLogWriter) Write(b []byte) (int, error) { 210 w.t.Logf("server log: %v", strings.TrimSpace(string(b))) 211 return len(b), nil 212 } 213 214 // Testing the newClientServerTest helper itself. 215 func TestNewClientServerTest(t *testing.T) { 216 run(t, testNewClientServerTest, []testMode{http1Mode, https1Mode, http2Mode}) 217 } 218 func testNewClientServerTest(t *testing.T, mode testMode) { 219 var got struct { 220 sync.Mutex 221 proto string 222 hasTLS bool 223 } 224 h := HandlerFunc(func(w ResponseWriter, r *Request) { 225 got.Lock() 226 defer got.Unlock() 227 got.proto = r.Proto 228 got.hasTLS = r.TLS != nil 229 }) 230 cst := newClientServerTest(t, mode, h) 231 if _, err := cst.c.Head(cst.ts.URL); err != nil { 232 t.Fatal(err) 233 } 234 var wantProto string 235 var wantTLS bool 236 switch mode { 237 case http1Mode: 238 wantProto = "HTTP/1.1" 239 wantTLS = false 240 case https1Mode: 241 wantProto = "HTTP/1.1" 242 wantTLS = true 243 case http2Mode: 244 wantProto = "HTTP/2.0" 245 wantTLS = true 246 } 247 if got.proto != wantProto { 248 t.Errorf("req.Proto = %q, want %q", got.proto, wantProto) 249 } 250 if got.hasTLS != wantTLS { 251 t.Errorf("req.TLS set: %v, want %v", got.hasTLS, wantTLS) 252 } 253 } 254 255 func TestChunkedResponseHeaders(t *testing.T) { run(t, testChunkedResponseHeaders) } 256 func testChunkedResponseHeaders(t *testing.T, mode testMode) { 257 log.SetOutput(io.Discard) // is noisy otherwise 258 defer log.SetOutput(os.Stderr) 259 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { 260 w.Header().Set("Content-Length", "intentional gibberish") // we check that this is deleted 261 w.(Flusher).Flush() 262 fmt.Fprintf(w, "I am a chunked response.") 263 })) 264 265 res, err := cst.c.Get(cst.ts.URL) 266 if err != nil { 267 t.Fatalf("Get error: %v", err) 268 } 269 defer res.Body.Close() 270 if g, e := res.ContentLength, int64(-1); g != e { 271 t.Errorf("expected ContentLength of %d; got %d", e, g) 272 } 273 wantTE := []string{"chunked"} 274 if mode == http2Mode { 275 wantTE = nil 276 } 277 if !reflect.DeepEqual(res.TransferEncoding, wantTE) { 278 t.Errorf("TransferEncoding = %v; want %v", res.TransferEncoding, wantTE) 279 } 280 if got, haveCL := res.Header["Content-Length"]; haveCL { 281 t.Errorf("Unexpected Content-Length: %q", got) 282 } 283 } 284 285 type reqFunc func(c *Client, url string) (*Response, error) 286 287 // h12Compare is a test that compares HTTP/1 and HTTP/2 behavior 288 // against each other. 289 type h12Compare struct { 290 Handler func(ResponseWriter, *Request) // required 291 ReqFunc reqFunc // optional 292 CheckResponse func(proto string, res *Response) // optional 293 EarlyCheckResponse func(proto string, res *Response) // optional; pre-normalize 294 Opts []any 295 } 296 297 func (tt h12Compare) reqFunc() reqFunc { 298 if tt.ReqFunc == nil { 299 return (*Client).Get 300 } 301 return tt.ReqFunc 302 } 303 304 func (tt h12Compare) run(t *testing.T) { 305 setParallel(t) 306 cst1 := newClientServerTest(t, http1Mode, HandlerFunc(tt.Handler), tt.Opts...) 307 defer cst1.close() 308 cst2 := newClientServerTest(t, http2Mode, HandlerFunc(tt.Handler), tt.Opts...) 309 defer cst2.close() 310 311 res1, err := tt.reqFunc()(cst1.c, cst1.ts.URL) 312 if err != nil { 313 t.Errorf("HTTP/1 request: %v", err) 314 return 315 } 316 res2, err := tt.reqFunc()(cst2.c, cst2.ts.URL) 317 if err != nil { 318 t.Errorf("HTTP/2 request: %v", err) 319 return 320 } 321 322 if fn := tt.EarlyCheckResponse; fn != nil { 323 fn("HTTP/1.1", res1) 324 fn("HTTP/2.0", res2) 325 } 326 327 tt.normalizeRes(t, res1, "HTTP/1.1") 328 tt.normalizeRes(t, res2, "HTTP/2.0") 329 res1body, res2body := res1.Body, res2.Body 330 331 eres1 := mostlyCopy(res1) 332 eres2 := mostlyCopy(res2) 333 if !reflect.DeepEqual(eres1, eres2) { 334 t.Errorf("Response headers to handler differed:\nhttp/1 (%v):\n\t%#v\nhttp/2 (%v):\n\t%#v", 335 cst1.ts.URL, eres1, cst2.ts.URL, eres2) 336 } 337 if !reflect.DeepEqual(res1body, res2body) { 338 t.Errorf("Response bodies to handler differed.\nhttp1: %v\nhttp2: %v\n", res1body, res2body) 339 } 340 if fn := tt.CheckResponse; fn != nil { 341 res1.Body, res2.Body = res1body, res2body 342 fn("HTTP/1.1", res1) 343 fn("HTTP/2.0", res2) 344 } 345 } 346 347 func mostlyCopy(r *Response) *Response { 348 c := *r 349 c.Body = nil 350 c.TransferEncoding = nil 351 c.TLS = nil 352 c.Request = nil 353 return &c 354 } 355 356 type slurpResult struct { 357 io.ReadCloser 358 body []byte 359 err error 360 } 361 362 func (sr slurpResult) String() string { return fmt.Sprintf("body %q; err %v", sr.body, sr.err) } 363 364 func (tt h12Compare) normalizeRes(t *testing.T, res *Response, wantProto string) { 365 if res.Proto == wantProto || res.Proto == "HTTP/IGNORE" { 366 res.Proto, res.ProtoMajor, res.ProtoMinor = "", 0, 0 367 } else { 368 t.Errorf("got %q response; want %q", res.Proto, wantProto) 369 } 370 slurp, err := io.ReadAll(res.Body) 371 372 res.Body.Close() 373 res.Body = slurpResult{ 374 ReadCloser: io.NopCloser(bytes.NewReader(slurp)), 375 body: slurp, 376 err: err, 377 } 378 for i, v := range res.Header["Date"] { 379 res.Header["Date"][i] = strings.Repeat("x", len(v)) 380 } 381 if res.Request == nil { 382 t.Errorf("for %s, no request", wantProto) 383 } 384 if (res.TLS != nil) != (wantProto == "HTTP/2.0") { 385 t.Errorf("TLS set = %v; want %v", res.TLS != nil, res.TLS == nil) 386 } 387 } 388 389 // Issue 13532 390 func TestH12_HeadContentLengthNoBody(t *testing.T) { 391 h12Compare{ 392 ReqFunc: (*Client).Head, 393 Handler: func(w ResponseWriter, r *Request) { 394 }, 395 }.run(t) 396 } 397 398 func TestH12_HeadContentLengthSmallBody(t *testing.T) { 399 h12Compare{ 400 ReqFunc: (*Client).Head, 401 Handler: func(w ResponseWriter, r *Request) { 402 io.WriteString(w, "small") 403 }, 404 }.run(t) 405 } 406 407 func TestH12_HeadContentLengthLargeBody(t *testing.T) { 408 h12Compare{ 409 ReqFunc: (*Client).Head, 410 Handler: func(w ResponseWriter, r *Request) { 411 chunk := strings.Repeat("x", 512<<10) 412 for i := 0; i < 10; i++ { 413 io.WriteString(w, chunk) 414 } 415 }, 416 }.run(t) 417 } 418 419 func TestH12_200NoBody(t *testing.T) { 420 h12Compare{Handler: func(w ResponseWriter, r *Request) {}}.run(t) 421 } 422 423 func TestH2_204NoBody(t *testing.T) { testH12_noBody(t, 204) } 424 func TestH2_304NoBody(t *testing.T) { testH12_noBody(t, 304) } 425 func TestH2_404NoBody(t *testing.T) { testH12_noBody(t, 404) } 426 427 func testH12_noBody(t *testing.T, status int) { 428 h12Compare{Handler: func(w ResponseWriter, r *Request) { 429 w.WriteHeader(status) 430 }}.run(t) 431 } 432 433 func TestH12_SmallBody(t *testing.T) { 434 h12Compare{Handler: func(w ResponseWriter, r *Request) { 435 io.WriteString(w, "small body") 436 }}.run(t) 437 } 438 439 func TestH12_ExplicitContentLength(t *testing.T) { 440 h12Compare{Handler: func(w ResponseWriter, r *Request) { 441 w.Header().Set("Content-Length", "3") 442 io.WriteString(w, "foo") 443 }}.run(t) 444 } 445 446 func TestH12_FlushBeforeBody(t *testing.T) { 447 h12Compare{Handler: func(w ResponseWriter, r *Request) { 448 w.(Flusher).Flush() 449 io.WriteString(w, "foo") 450 }}.run(t) 451 } 452 453 func TestH12_FlushMidBody(t *testing.T) { 454 h12Compare{Handler: func(w ResponseWriter, r *Request) { 455 io.WriteString(w, "foo") 456 w.(Flusher).Flush() 457 io.WriteString(w, "bar") 458 }}.run(t) 459 } 460 461 func TestH12_Head_ExplicitLen(t *testing.T) { 462 h12Compare{ 463 ReqFunc: (*Client).Head, 464 Handler: func(w ResponseWriter, r *Request) { 465 if r.Method != "HEAD" { 466 t.Errorf("unexpected method %q", r.Method) 467 } 468 w.Header().Set("Content-Length", "1235") 469 }, 470 }.run(t) 471 } 472 473 func TestH12_Head_ImplicitLen(t *testing.T) { 474 h12Compare{ 475 ReqFunc: (*Client).Head, 476 Handler: func(w ResponseWriter, r *Request) { 477 if r.Method != "HEAD" { 478 t.Errorf("unexpected method %q", r.Method) 479 } 480 io.WriteString(w, "foo") 481 }, 482 }.run(t) 483 } 484 485 func TestH12_HandlerWritesTooLittle(t *testing.T) { 486 h12Compare{ 487 Handler: func(w ResponseWriter, r *Request) { 488 w.Header().Set("Content-Length", "3") 489 io.WriteString(w, "12") // one byte short 490 }, 491 CheckResponse: func(proto string, res *Response) { 492 sr, ok := res.Body.(slurpResult) 493 if !ok { 494 t.Errorf("%s body is %T; want slurpResult", proto, res.Body) 495 return 496 } 497 if sr.err != io.ErrUnexpectedEOF { 498 t.Errorf("%s read error = %v; want io.ErrUnexpectedEOF", proto, sr.err) 499 } 500 if string(sr.body) != "12" { 501 t.Errorf("%s body = %q; want %q", proto, sr.body, "12") 502 } 503 }, 504 }.run(t) 505 } 506 507 // Tests that the HTTP/1 and HTTP/2 servers prevent handlers from 508 // writing more than they declared. This test does not test whether 509 // the transport deals with too much data, though, since the server 510 // doesn't make it possible to send bogus data. For those tests, see 511 // transport_test.go (for HTTP/1) or x/net/http2/transport_test.go 512 // (for HTTP/2). 513 func TestH12_HandlerWritesTooMuch(t *testing.T) { 514 h12Compare{ 515 Handler: func(w ResponseWriter, r *Request) { 516 w.Header().Set("Content-Length", "3") 517 w.(Flusher).Flush() 518 io.WriteString(w, "123") 519 w.(Flusher).Flush() 520 n, err := io.WriteString(w, "x") // too many 521 if n > 0 || err == nil { 522 t.Errorf("for proto %q, final write = %v, %v; want 0, some error", r.Proto, n, err) 523 } 524 }, 525 }.run(t) 526 } 527 528 // Verify that both our HTTP/1 and HTTP/2 request and auto-decompress gzip. 529 // Some hosts send gzip even if you don't ask for it; see golang.org/issue/13298 530 func TestH12_AutoGzip(t *testing.T) { 531 h12Compare{ 532 Handler: func(w ResponseWriter, r *Request) { 533 if ae := r.Header.Get("Accept-Encoding"); ae != "gzip" { 534 t.Errorf("%s Accept-Encoding = %q; want gzip", r.Proto, ae) 535 } 536 w.Header().Set("Content-Encoding", "gzip") 537 gz := gzip.NewWriter(w) 538 io.WriteString(gz, "I am some gzipped content. Go go go go go go go go go go go go should compress well.") 539 gz.Close() 540 }, 541 }.run(t) 542 } 543 544 func TestH12_AutoGzip_Disabled(t *testing.T) { 545 h12Compare{ 546 Opts: []any{ 547 func(tr *Transport) { tr.DisableCompression = true }, 548 }, 549 Handler: func(w ResponseWriter, r *Request) { 550 fmt.Fprintf(w, "%q", r.Header["Accept-Encoding"]) 551 if ae := r.Header.Get("Accept-Encoding"); ae != "" { 552 t.Errorf("%s Accept-Encoding = %q; want empty", r.Proto, ae) 553 } 554 }, 555 }.run(t) 556 } 557 558 // Test304Responses verifies that 304s don't declare that they're 559 // chunking in their response headers and aren't allowed to produce 560 // output. 561 func Test304Responses(t *testing.T) { run(t, test304Responses) } 562 func test304Responses(t *testing.T, mode testMode) { 563 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { 564 w.WriteHeader(StatusNotModified) 565 _, err := w.Write([]byte("illegal body")) 566 if err != ErrBodyNotAllowed { 567 t.Errorf("on Write, expected ErrBodyNotAllowed, got %v", err) 568 } 569 })) 570 defer cst.close() 571 res, err := cst.c.Get(cst.ts.URL) 572 if err != nil { 573 t.Fatal(err) 574 } 575 if len(res.TransferEncoding) > 0 { 576 t.Errorf("expected no TransferEncoding; got %v", res.TransferEncoding) 577 } 578 body, err := io.ReadAll(res.Body) 579 if err != nil { 580 t.Error(err) 581 } 582 if len(body) > 0 { 583 t.Errorf("got unexpected body %q", string(body)) 584 } 585 } 586 587 func TestH12_ServerEmptyContentLength(t *testing.T) { 588 h12Compare{ 589 Handler: func(w ResponseWriter, r *Request) { 590 w.Header()["Content-Type"] = []string{""} 591 io.WriteString(w, "<html><body>hi</body></html>") 592 }, 593 }.run(t) 594 } 595 596 func TestH12_RequestContentLength_Known_NonZero(t *testing.T) { 597 h12requestContentLength(t, func() io.Reader { return strings.NewReader("FOUR") }, 4) 598 } 599 600 func TestH12_RequestContentLength_Known_Zero(t *testing.T) { 601 h12requestContentLength(t, func() io.Reader { return nil }, 0) 602 } 603 604 func TestH12_RequestContentLength_Unknown(t *testing.T) { 605 h12requestContentLength(t, func() io.Reader { return struct{ io.Reader }{strings.NewReader("Stuff")} }, -1) 606 } 607 608 func h12requestContentLength(t *testing.T, bodyfn func() io.Reader, wantLen int64) { 609 h12Compare{ 610 Handler: func(w ResponseWriter, r *Request) { 611 w.Header().Set("Got-Length", fmt.Sprint(r.ContentLength)) 612 fmt.Fprintf(w, "Req.ContentLength=%v", r.ContentLength) 613 }, 614 ReqFunc: func(c *Client, url string) (*Response, error) { 615 return c.Post(url, "text/plain", bodyfn()) 616 }, 617 CheckResponse: func(proto string, res *Response) { 618 if got, want := res.Header.Get("Got-Length"), fmt.Sprint(wantLen); got != want { 619 t.Errorf("Proto %q got length %q; want %q", proto, got, want) 620 } 621 }, 622 }.run(t) 623 } 624 625 // Tests that closing the Request.Cancel channel also while still 626 // reading the response body. Issue 13159. 627 func TestCancelRequestMidBody(t *testing.T) { run(t, testCancelRequestMidBody) } 628 func testCancelRequestMidBody(t *testing.T, mode testMode) { 629 unblock := make(chan bool) 630 didFlush := make(chan bool, 1) 631 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { 632 io.WriteString(w, "Hello") 633 w.(Flusher).Flush() 634 didFlush <- true 635 <-unblock 636 io.WriteString(w, ", world.") 637 })) 638 defer close(unblock) 639 640 req, _ := NewRequest("GET", cst.ts.URL, nil) 641 cancel := make(chan struct{}) 642 req.Cancel = cancel 643 644 res, err := cst.c.Do(req) 645 if err != nil { 646 t.Fatal(err) 647 } 648 defer res.Body.Close() 649 <-didFlush 650 651 // Read a bit before we cancel. (Issue 13626) 652 // We should have "Hello" at least sitting there. 653 firstRead := make([]byte, 10) 654 n, err := res.Body.Read(firstRead) 655 if err != nil { 656 t.Fatal(err) 657 } 658 firstRead = firstRead[:n] 659 660 close(cancel) 661 662 rest, err := io.ReadAll(res.Body) 663 all := string(firstRead) + string(rest) 664 if all != "Hello" { 665 t.Errorf("Read %q (%q + %q); want Hello", all, firstRead, rest) 666 } 667 if err != ExportErrRequestCanceled { 668 t.Errorf("ReadAll error = %v; want %v", err, ExportErrRequestCanceled) 669 } 670 } 671 672 // Tests that clients can send trailers to a server and that the server can read them. 673 func TestTrailersClientToServer(t *testing.T) { run(t, testTrailersClientToServer) } 674 func testTrailersClientToServer(t *testing.T, mode testMode) { 675 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { 676 var decl []string 677 for k := range r.Trailer { 678 decl = append(decl, k) 679 } 680 sort.Strings(decl) 681 682 slurp, err := io.ReadAll(r.Body) 683 if err != nil { 684 t.Errorf("Server reading request body: %v", err) 685 } 686 if string(slurp) != "foo" { 687 t.Errorf("Server read request body %q; want foo", slurp) 688 } 689 if r.Trailer == nil { 690 io.WriteString(w, "nil Trailer") 691 } else { 692 fmt.Fprintf(w, "decl: %v, vals: %s, %s", 693 decl, 694 r.Trailer.Get("Client-Trailer-A"), 695 r.Trailer.Get("Client-Trailer-B")) 696 } 697 })) 698 699 var req *Request 700 req, _ = NewRequest("POST", cst.ts.URL, io.MultiReader( 701 eofReaderFunc(func() { 702 req.Trailer["Client-Trailer-A"] = []string{"valuea"} 703 }), 704 strings.NewReader("foo"), 705 eofReaderFunc(func() { 706 req.Trailer["Client-Trailer-B"] = []string{"valueb"} 707 }), 708 )) 709 req.Trailer = Header{ 710 "Client-Trailer-A": nil, // to be set later 711 "Client-Trailer-B": nil, // to be set later 712 } 713 req.ContentLength = -1 714 res, err := cst.c.Do(req) 715 if err != nil { 716 t.Fatal(err) 717 } 718 if err := wantBody(res, err, "decl: [Client-Trailer-A Client-Trailer-B], vals: valuea, valueb"); err != nil { 719 t.Error(err) 720 } 721 } 722 723 // Tests that servers send trailers to a client and that the client can read them. 724 func TestTrailersServerToClient(t *testing.T) { 725 run(t, func(t *testing.T, mode testMode) { 726 testTrailersServerToClient(t, mode, false) 727 }) 728 } 729 func TestTrailersServerToClientFlush(t *testing.T) { 730 run(t, func(t *testing.T, mode testMode) { 731 testTrailersServerToClient(t, mode, true) 732 }) 733 } 734 735 func testTrailersServerToClient(t *testing.T, mode testMode, flush bool) { 736 const body = "Some body" 737 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { 738 w.Header().Set("Trailer", "Server-Trailer-A, Server-Trailer-B") 739 w.Header().Add("Trailer", "Server-Trailer-C") 740 741 io.WriteString(w, body) 742 if flush { 743 w.(Flusher).Flush() 744 } 745 746 // How handlers set Trailers: declare it ahead of time 747 // with the Trailer header, and then mutate the 748 // Header() of those values later, after the response 749 // has been written (we wrote to w above). 750 w.Header().Set("Server-Trailer-A", "valuea") 751 w.Header().Set("Server-Trailer-C", "valuec") // skipping B 752 w.Header().Set("Server-Trailer-NotDeclared", "should be omitted") 753 })) 754 755 res, err := cst.c.Get(cst.ts.URL) 756 if err != nil { 757 t.Fatal(err) 758 } 759 760 wantHeader := Header{ 761 "Content-Type": {"text/plain; charset=utf-8"}, 762 } 763 wantLen := -1 764 if mode == http2Mode && !flush { 765 // In HTTP/1.1, any use of trailers forces HTTP/1.1 766 // chunking and a flush at the first write. That's 767 // unnecessary with HTTP/2's framing, so the server 768 // is able to calculate the length while still sending 769 // trailers afterwards. 770 wantLen = len(body) 771 wantHeader["Content-Length"] = []string{fmt.Sprint(wantLen)} 772 } 773 if res.ContentLength != int64(wantLen) { 774 t.Errorf("ContentLength = %v; want %v", res.ContentLength, wantLen) 775 } 776 777 delete(res.Header, "Date") // irrelevant for test 778 if !reflect.DeepEqual(res.Header, wantHeader) { 779 t.Errorf("Header = %v; want %v", res.Header, wantHeader) 780 } 781 782 if got, want := res.Trailer, (Header{ 783 "Server-Trailer-A": nil, 784 "Server-Trailer-B": nil, 785 "Server-Trailer-C": nil, 786 }); !reflect.DeepEqual(got, want) { 787 t.Errorf("Trailer before body read = %v; want %v", got, want) 788 } 789 790 if err := wantBody(res, nil, body); err != nil { 791 t.Fatal(err) 792 } 793 794 if got, want := res.Trailer, (Header{ 795 "Server-Trailer-A": {"valuea"}, 796 "Server-Trailer-B": nil, 797 "Server-Trailer-C": {"valuec"}, 798 }); !reflect.DeepEqual(got, want) { 799 t.Errorf("Trailer after body read = %v; want %v", got, want) 800 } 801 } 802 803 // Don't allow a Body.Read after Body.Close. Issue 13648. 804 func TestResponseBodyReadAfterClose(t *testing.T) { run(t, testResponseBodyReadAfterClose) } 805 func testResponseBodyReadAfterClose(t *testing.T, mode testMode) { 806 const body = "Some body" 807 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { 808 io.WriteString(w, body) 809 })) 810 res, err := cst.c.Get(cst.ts.URL) 811 if err != nil { 812 t.Fatal(err) 813 } 814 res.Body.Close() 815 data, err := io.ReadAll(res.Body) 816 if len(data) != 0 || err == nil { 817 t.Fatalf("ReadAll returned %q, %v; want error", data, err) 818 } 819 } 820 821 func TestConcurrentReadWriteReqBody(t *testing.T) { run(t, testConcurrentReadWriteReqBody) } 822 func testConcurrentReadWriteReqBody(t *testing.T, mode testMode) { 823 const reqBody = "some request body" 824 const resBody = "some response body" 825 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { 826 var wg sync.WaitGroup 827 wg.Add(2) 828 didRead := make(chan bool, 1) 829 // Read in one goroutine. 830 go func() { 831 defer wg.Done() 832 data, err := io.ReadAll(r.Body) 833 if string(data) != reqBody { 834 t.Errorf("Handler read %q; want %q", data, reqBody) 835 } 836 if err != nil { 837 t.Errorf("Handler Read: %v", err) 838 } 839 didRead <- true 840 }() 841 // Write in another goroutine. 842 go func() { 843 defer wg.Done() 844 if mode != http2Mode { 845 // our HTTP/1 implementation intentionally 846 // doesn't permit writes during read (mostly 847 // due to it being undefined); if that is ever 848 // relaxed, change this. 849 <-didRead 850 } 851 io.WriteString(w, resBody) 852 }() 853 wg.Wait() 854 })) 855 req, _ := NewRequest("POST", cst.ts.URL, strings.NewReader(reqBody)) 856 req.Header.Add("Expect", "100-continue") // just to complicate things 857 res, err := cst.c.Do(req) 858 if err != nil { 859 t.Fatal(err) 860 } 861 data, err := io.ReadAll(res.Body) 862 defer res.Body.Close() 863 if err != nil { 864 t.Fatal(err) 865 } 866 if string(data) != resBody { 867 t.Errorf("read %q; want %q", data, resBody) 868 } 869 } 870 871 func TestConnectRequest(t *testing.T) { run(t, testConnectRequest) } 872 func testConnectRequest(t *testing.T, mode testMode) { 873 gotc := make(chan *Request, 1) 874 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { 875 gotc <- r 876 })) 877 878 u, err := url.Parse(cst.ts.URL) 879 if err != nil { 880 t.Fatal(err) 881 } 882 883 tests := []struct { 884 req *Request 885 want string 886 }{ 887 { 888 req: &Request{ 889 Method: "CONNECT", 890 Header: Header{}, 891 URL: u, 892 }, 893 want: u.Host, 894 }, 895 { 896 req: &Request{ 897 Method: "CONNECT", 898 Header: Header{}, 899 URL: u, 900 Host: "example.com:123", 901 }, 902 want: "example.com:123", 903 }, 904 } 905 906 for i, tt := range tests { 907 res, err := cst.c.Do(tt.req) 908 if err != nil { 909 t.Errorf("%d. RoundTrip = %v", i, err) 910 continue 911 } 912 res.Body.Close() 913 req := <-gotc 914 if req.Method != "CONNECT" { 915 t.Errorf("method = %q; want CONNECT", req.Method) 916 } 917 if req.Host != tt.want { 918 t.Errorf("Host = %q; want %q", req.Host, tt.want) 919 } 920 if req.URL.Host != tt.want { 921 t.Errorf("URL.Host = %q; want %q", req.URL.Host, tt.want) 922 } 923 } 924 } 925 926 func TestTransportUserAgent(t *testing.T) { run(t, testTransportUserAgent) } 927 func testTransportUserAgent(t *testing.T, mode testMode) { 928 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { 929 fmt.Fprintf(w, "%q", r.Header["User-Agent"]) 930 })) 931 932 either := func(a, b string) string { 933 if mode == http2Mode { 934 return b 935 } 936 return a 937 } 938 939 tests := []struct { 940 setup func(*Request) 941 want string 942 }{ 943 { 944 func(r *Request) {}, 945 either(`["Go-http-client/1.1"]`, `["Go-http-client/2.0"]`), 946 }, 947 { 948 func(r *Request) { r.Header.Set("User-Agent", "foo/1.2.3") }, 949 `["foo/1.2.3"]`, 950 }, 951 { 952 func(r *Request) { r.Header["User-Agent"] = []string{"single", "or", "multiple"} }, 953 `["single"]`, 954 }, 955 { 956 func(r *Request) { r.Header.Set("User-Agent", "") }, 957 `[]`, 958 }, 959 { 960 func(r *Request) { r.Header["User-Agent"] = nil }, 961 `[]`, 962 }, 963 } 964 for i, tt := range tests { 965 req, _ := NewRequest("GET", cst.ts.URL, nil) 966 tt.setup(req) 967 res, err := cst.c.Do(req) 968 if err != nil { 969 t.Errorf("%d. RoundTrip = %v", i, err) 970 continue 971 } 972 slurp, err := io.ReadAll(res.Body) 973 res.Body.Close() 974 if err != nil { 975 t.Errorf("%d. read body = %v", i, err) 976 continue 977 } 978 if string(slurp) != tt.want { 979 t.Errorf("%d. body mismatch.\n got: %s\nwant: %s\n", i, slurp, tt.want) 980 } 981 } 982 } 983 984 func TestStarRequestMethod(t *testing.T) { 985 for _, method := range []string{"FOO", "OPTIONS"} { 986 t.Run(method, func(t *testing.T) { 987 run(t, func(t *testing.T, mode testMode) { 988 testStarRequest(t, method, mode) 989 }) 990 }) 991 } 992 } 993 func testStarRequest(t *testing.T, method string, mode testMode) { 994 gotc := make(chan *Request, 1) 995 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { 996 w.Header().Set("foo", "bar") 997 gotc <- r 998 w.(Flusher).Flush() 999 })) 1000 1001 u, err := url.Parse(cst.ts.URL) 1002 if err != nil { 1003 t.Fatal(err) 1004 } 1005 u.Path = "*" 1006 1007 req := &Request{ 1008 Method: method, 1009 Header: Header{}, 1010 URL: u, 1011 } 1012 1013 res, err := cst.c.Do(req) 1014 if err != nil { 1015 t.Fatalf("RoundTrip = %v", err) 1016 } 1017 res.Body.Close() 1018 1019 wantFoo := "bar" 1020 wantLen := int64(-1) 1021 if method == "OPTIONS" { 1022 wantFoo = "" 1023 wantLen = 0 1024 } 1025 if res.StatusCode != 200 { 1026 t.Errorf("status code = %v; want %d", res.Status, 200) 1027 } 1028 if res.ContentLength != wantLen { 1029 t.Errorf("content length = %v; want %d", res.ContentLength, wantLen) 1030 } 1031 if got := res.Header.Get("foo"); got != wantFoo { 1032 t.Errorf("response \"foo\" header = %q; want %q", got, wantFoo) 1033 } 1034 select { 1035 case req = <-gotc: 1036 default: 1037 req = nil 1038 } 1039 if req == nil { 1040 if method != "OPTIONS" { 1041 t.Fatalf("handler never got request") 1042 } 1043 return 1044 } 1045 if req.Method != method { 1046 t.Errorf("method = %q; want %q", req.Method, method) 1047 } 1048 if req.URL.Path != "*" { 1049 t.Errorf("URL.Path = %q; want *", req.URL.Path) 1050 } 1051 if req.RequestURI != "*" { 1052 t.Errorf("RequestURI = %q; want *", req.RequestURI) 1053 } 1054 } 1055 1056 // Issue 13957 1057 func TestTransportDiscardsUnneededConns(t *testing.T) { 1058 run(t, testTransportDiscardsUnneededConns, []testMode{http2Mode}) 1059 } 1060 func testTransportDiscardsUnneededConns(t *testing.T, mode testMode) { 1061 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { 1062 fmt.Fprintf(w, "Hello, %v", r.RemoteAddr) 1063 })) 1064 defer cst.close() 1065 1066 var numOpen, numClose int32 // atomic 1067 1068 tlsConfig := &tls.Config{InsecureSkipVerify: true} 1069 tr := &Transport{ 1070 TLSClientConfig: tlsConfig, 1071 DialTLS: func(_, addr string) (net.Conn, error) { 1072 time.Sleep(10 * time.Millisecond) 1073 rc, err := net.Dial("tcp", addr) 1074 if err != nil { 1075 return nil, err 1076 } 1077 atomic.AddInt32(&numOpen, 1) 1078 c := noteCloseConn{rc, func() { atomic.AddInt32(&numClose, 1) }} 1079 return tls.Client(c, tlsConfig), nil 1080 }, 1081 } 1082 if err := ExportHttp2ConfigureTransport(tr); err != nil { 1083 t.Fatal(err) 1084 } 1085 defer tr.CloseIdleConnections() 1086 1087 c := &Client{Transport: tr} 1088 1089 const N = 10 1090 gotBody := make(chan string, N) 1091 var wg sync.WaitGroup 1092 for i := 0; i < N; i++ { 1093 wg.Add(1) 1094 go func() { 1095 defer wg.Done() 1096 resp, err := c.Get(cst.ts.URL) 1097 if err != nil { 1098 // Try to work around spurious connection reset on loaded system. 1099 // See golang.org/issue/33585 and golang.org/issue/36797. 1100 time.Sleep(10 * time.Millisecond) 1101 resp, err = c.Get(cst.ts.URL) 1102 if err != nil { 1103 t.Errorf("Get: %v", err) 1104 return 1105 } 1106 } 1107 defer resp.Body.Close() 1108 slurp, err := io.ReadAll(resp.Body) 1109 if err != nil { 1110 t.Error(err) 1111 } 1112 gotBody <- string(slurp) 1113 }() 1114 } 1115 wg.Wait() 1116 close(gotBody) 1117 1118 var last string 1119 for got := range gotBody { 1120 if last == "" { 1121 last = got 1122 continue 1123 } 1124 if got != last { 1125 t.Errorf("Response body changed: %q -> %q", last, got) 1126 } 1127 } 1128 1129 var open, close int32 1130 for i := 0; i < 150; i++ { 1131 open, close = atomic.LoadInt32(&numOpen), atomic.LoadInt32(&numClose) 1132 if open < 1 { 1133 t.Fatalf("open = %d; want at least", open) 1134 } 1135 if close == open-1 { 1136 // Success 1137 return 1138 } 1139 time.Sleep(10 * time.Millisecond) 1140 } 1141 t.Errorf("%d connections opened, %d closed; want %d to close", open, close, open-1) 1142 } 1143 1144 // tests that Transport doesn't retain a pointer to the provided request. 1145 func TestTransportGCRequest(t *testing.T) { 1146 run(t, func(t *testing.T, mode testMode) { 1147 t.Run("Body", func(t *testing.T) { testTransportGCRequest(t, mode, true) }) 1148 t.Run("NoBody", func(t *testing.T) { testTransportGCRequest(t, mode, false) }) 1149 }) 1150 } 1151 func testTransportGCRequest(t *testing.T, mode testMode, body bool) { 1152 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { 1153 io.ReadAll(r.Body) 1154 if body { 1155 io.WriteString(w, "Hello.") 1156 } 1157 })) 1158 1159 didGC := make(chan struct{}) 1160 (func() { 1161 body := strings.NewReader("some body") 1162 req, _ := NewRequest("POST", cst.ts.URL, body) 1163 runtime.SetFinalizer(req, func(*Request) { close(didGC) }) 1164 res, err := cst.c.Do(req) 1165 if err != nil { 1166 t.Fatal(err) 1167 } 1168 if _, err := io.ReadAll(res.Body); err != nil { 1169 t.Fatal(err) 1170 } 1171 if err := res.Body.Close(); err != nil { 1172 t.Fatal(err) 1173 } 1174 })() 1175 for { 1176 select { 1177 case <-didGC: 1178 return 1179 case <-time.After(1 * time.Millisecond): 1180 runtime.GC() 1181 } 1182 } 1183 } 1184 1185 func TestTransportRejectsInvalidHeaders(t *testing.T) { run(t, testTransportRejectsInvalidHeaders) } 1186 func testTransportRejectsInvalidHeaders(t *testing.T, mode testMode) { 1187 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { 1188 fmt.Fprintf(w, "Handler saw headers: %q", r.Header) 1189 }), optQuietLog) 1190 cst.tr.DisableKeepAlives = true 1191 1192 tests := []struct { 1193 key, val string 1194 ok bool 1195 }{ 1196 {"Foo", "capital-key", true}, // verify h2 allows capital keys 1197 {"Foo", "foo\x00bar", false}, // \x00 byte in value not allowed 1198 {"Foo", "two\nlines", false}, // \n byte in value not allowed 1199 {"bogus\nkey", "v", false}, // \n byte also not allowed in key 1200 {"A space", "v", false}, // spaces in keys not allowed 1201 {"имя", "v", false}, // key must be ascii 1202 {"name", "валю", true}, // value may be non-ascii 1203 {"", "v", false}, // key must be non-empty 1204 {"k", "", true}, // value may be empty 1205 } 1206 for _, tt := range tests { 1207 dialedc := make(chan bool, 1) 1208 cst.tr.Dial = func(netw, addr string) (net.Conn, error) { 1209 dialedc <- true 1210 return net.Dial(netw, addr) 1211 } 1212 req, _ := NewRequest("GET", cst.ts.URL, nil) 1213 req.Header[tt.key] = []string{tt.val} 1214 res, err := cst.c.Do(req) 1215 var body []byte 1216 if err == nil { 1217 body, _ = io.ReadAll(res.Body) 1218 res.Body.Close() 1219 } 1220 var dialed bool 1221 select { 1222 case <-dialedc: 1223 dialed = true 1224 default: 1225 } 1226 1227 if !tt.ok && dialed { 1228 t.Errorf("For key %q, value %q, transport dialed. Expected local failure. Response was: (%v, %v)\nServer replied with: %s", tt.key, tt.val, res, err, body) 1229 } else if (err == nil) != tt.ok { 1230 t.Errorf("For key %q, value %q; got err = %v; want ok=%v", tt.key, tt.val, err, tt.ok) 1231 } 1232 } 1233 } 1234 1235 func TestInterruptWithPanic(t *testing.T) { 1236 run(t, func(t *testing.T, mode testMode) { 1237 t.Run("boom", func(t *testing.T) { testInterruptWithPanic(t, mode, "boom") }) 1238 t.Run("nil", func(t *testing.T) { t.Setenv("GODEBUG", "panicnil=1"); testInterruptWithPanic(t, mode, nil) }) 1239 t.Run("ErrAbortHandler", func(t *testing.T) { testInterruptWithPanic(t, mode, ErrAbortHandler) }) 1240 }, testNotParallel) 1241 } 1242 func testInterruptWithPanic(t *testing.T, mode testMode, panicValue any) { 1243 const msg = "hello" 1244 1245 testDone := make(chan struct{}) 1246 defer close(testDone) 1247 1248 var errorLog lockedBytesBuffer 1249 gotHeaders := make(chan bool, 1) 1250 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { 1251 io.WriteString(w, msg) 1252 w.(Flusher).Flush() 1253 1254 select { 1255 case <-gotHeaders: 1256 case <-testDone: 1257 } 1258 panic(panicValue) 1259 }), func(ts *httptest.Server) { 1260 ts.Config.ErrorLog = log.New(&errorLog, "", 0) 1261 }) 1262 res, err := cst.c.Get(cst.ts.URL) 1263 if err != nil { 1264 t.Fatal(err) 1265 } 1266 gotHeaders <- true 1267 defer res.Body.Close() 1268 slurp, err := io.ReadAll(res.Body) 1269 if string(slurp) != msg { 1270 t.Errorf("client read %q; want %q", slurp, msg) 1271 } 1272 if err == nil { 1273 t.Errorf("client read all successfully; want some error") 1274 } 1275 logOutput := func() string { 1276 errorLog.Lock() 1277 defer errorLog.Unlock() 1278 return errorLog.String() 1279 } 1280 wantStackLogged := panicValue != nil && panicValue != ErrAbortHandler 1281 1282 waitCondition(t, 10*time.Millisecond, func(d time.Duration) bool { 1283 gotLog := logOutput() 1284 if !wantStackLogged { 1285 if gotLog == "" { 1286 return true 1287 } 1288 t.Fatalf("want no log output; got: %s", gotLog) 1289 } 1290 if gotLog == "" { 1291 if d > 0 { 1292 t.Logf("wanted a stack trace logged; got nothing after %v", d) 1293 } 1294 return false 1295 } 1296 if !strings.Contains(gotLog, "created by ") && strings.Count(gotLog, "\n") < 6 { 1297 if d > 0 { 1298 t.Logf("output doesn't look like a panic stack trace after %v. Got: %s", d, gotLog) 1299 } 1300 return false 1301 } 1302 return true 1303 }) 1304 } 1305 1306 type lockedBytesBuffer struct { 1307 sync.Mutex 1308 bytes.Buffer 1309 } 1310 1311 func (b *lockedBytesBuffer) Write(p []byte) (int, error) { 1312 b.Lock() 1313 defer b.Unlock() 1314 return b.Buffer.Write(p) 1315 } 1316 1317 // Issue 15366 1318 func TestH12_AutoGzipWithDumpResponse(t *testing.T) { 1319 h12Compare{ 1320 Handler: func(w ResponseWriter, r *Request) { 1321 h := w.Header() 1322 h.Set("Content-Encoding", "gzip") 1323 h.Set("Content-Length", "23") 1324 io.WriteString(w, "\x1f\x8b\b\x00\x00\x00\x00\x00\x00\x00s\xf3\xf7\a\x00\xab'\xd4\x1a\x03\x00\x00\x00") 1325 }, 1326 EarlyCheckResponse: func(proto string, res *Response) { 1327 if !res.Uncompressed { 1328 t.Errorf("%s: expected Uncompressed to be set", proto) 1329 } 1330 dump, err := httputil.DumpResponse(res, true) 1331 if err != nil { 1332 t.Errorf("%s: DumpResponse: %v", proto, err) 1333 return 1334 } 1335 if strings.Contains(string(dump), "Connection: close") { 1336 t.Errorf("%s: should not see \"Connection: close\" in dump; got:\n%s", proto, dump) 1337 } 1338 if !strings.Contains(string(dump), "FOO") { 1339 t.Errorf("%s: should see \"FOO\" in response; got:\n%s", proto, dump) 1340 } 1341 }, 1342 }.run(t) 1343 } 1344 1345 // Issue 14607 1346 func TestCloseIdleConnections(t *testing.T) { run(t, testCloseIdleConnections) } 1347 func testCloseIdleConnections(t *testing.T, mode testMode) { 1348 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { 1349 w.Header().Set("X-Addr", r.RemoteAddr) 1350 })) 1351 get := func() string { 1352 res, err := cst.c.Get(cst.ts.URL) 1353 if err != nil { 1354 t.Fatal(err) 1355 } 1356 res.Body.Close() 1357 v := res.Header.Get("X-Addr") 1358 if v == "" { 1359 t.Fatal("didn't get X-Addr") 1360 } 1361 return v 1362 } 1363 a1 := get() 1364 cst.tr.CloseIdleConnections() 1365 a2 := get() 1366 if a1 == a2 { 1367 t.Errorf("didn't close connection") 1368 } 1369 } 1370 1371 type noteCloseConn struct { 1372 net.Conn 1373 closeFunc func() 1374 } 1375 1376 func (x noteCloseConn) Close() error { 1377 x.closeFunc() 1378 return x.Conn.Close() 1379 } 1380 1381 type testErrorReader struct{ t *testing.T } 1382 1383 func (r testErrorReader) Read(p []byte) (n int, err error) { 1384 r.t.Error("unexpected Read call") 1385 return 0, io.EOF 1386 } 1387 1388 func TestNoSniffExpectRequestBody(t *testing.T) { run(t, testNoSniffExpectRequestBody) } 1389 func testNoSniffExpectRequestBody(t *testing.T, mode testMode) { 1390 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { 1391 w.WriteHeader(StatusUnauthorized) 1392 })) 1393 1394 // Set ExpectContinueTimeout non-zero so RoundTrip won't try to write it. 1395 cst.tr.ExpectContinueTimeout = 10 * time.Second 1396 1397 req, err := NewRequest("POST", cst.ts.URL, testErrorReader{t}) 1398 if err != nil { 1399 t.Fatal(err) 1400 } 1401 req.ContentLength = 0 // so transport is tempted to sniff it 1402 req.Header.Set("Expect", "100-continue") 1403 res, err := cst.tr.RoundTrip(req) 1404 if err != nil { 1405 t.Fatal(err) 1406 } 1407 defer res.Body.Close() 1408 if res.StatusCode != StatusUnauthorized { 1409 t.Errorf("status code = %v; want %v", res.StatusCode, StatusUnauthorized) 1410 } 1411 } 1412 1413 func TestServerUndeclaredTrailers(t *testing.T) { run(t, testServerUndeclaredTrailers) } 1414 func testServerUndeclaredTrailers(t *testing.T, mode testMode) { 1415 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { 1416 w.Header().Set("Foo", "Bar") 1417 w.Header().Set("Trailer:Foo", "Baz") 1418 w.(Flusher).Flush() 1419 w.Header().Add("Trailer:Foo", "Baz2") 1420 w.Header().Set("Trailer:Bar", "Quux") 1421 })) 1422 res, err := cst.c.Get(cst.ts.URL) 1423 if err != nil { 1424 t.Fatal(err) 1425 } 1426 if _, err := io.Copy(io.Discard, res.Body); err != nil { 1427 t.Fatal(err) 1428 } 1429 res.Body.Close() 1430 delete(res.Header, "Date") 1431 delete(res.Header, "Content-Type") 1432 1433 if want := (Header{"Foo": {"Bar"}}); !reflect.DeepEqual(res.Header, want) { 1434 t.Errorf("Header = %#v; want %#v", res.Header, want) 1435 } 1436 if want := (Header{"Foo": {"Baz", "Baz2"}, "Bar": {"Quux"}}); !reflect.DeepEqual(res.Trailer, want) { 1437 t.Errorf("Trailer = %#v; want %#v", res.Trailer, want) 1438 } 1439 } 1440 1441 func TestBadResponseAfterReadingBody(t *testing.T) { 1442 run(t, testBadResponseAfterReadingBody, []testMode{http1Mode}) 1443 } 1444 func testBadResponseAfterReadingBody(t *testing.T, mode testMode) { 1445 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { 1446 _, err := io.Copy(io.Discard, r.Body) 1447 if err != nil { 1448 t.Fatal(err) 1449 } 1450 c, _, err := w.(Hijacker).Hijack() 1451 if err != nil { 1452 t.Fatal(err) 1453 } 1454 defer c.Close() 1455 fmt.Fprintln(c, "some bogus crap") 1456 })) 1457 1458 closes := 0 1459 res, err := cst.c.Post(cst.ts.URL, "text/plain", countCloseReader{&closes, strings.NewReader("hello")}) 1460 if err == nil { 1461 res.Body.Close() 1462 t.Fatal("expected an error to be returned from Post") 1463 } 1464 if closes != 1 { 1465 t.Errorf("closes = %d; want 1", closes) 1466 } 1467 } 1468 1469 func TestWriteHeader0(t *testing.T) { run(t, testWriteHeader0) } 1470 func testWriteHeader0(t *testing.T, mode testMode) { 1471 gotpanic := make(chan bool, 1) 1472 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { 1473 defer close(gotpanic) 1474 defer func() { 1475 if e := recover(); e != nil { 1476 got := fmt.Sprintf("%T, %v", e, e) 1477 want := "string, invalid WriteHeader code 0" 1478 if got != want { 1479 t.Errorf("unexpected panic value:\n got: %v\nwant: %v\n", got, want) 1480 } 1481 gotpanic <- true 1482 1483 // Set an explicit 503. This also tests that the WriteHeader call panics 1484 // before it recorded that an explicit value was set and that bogus 1485 // value wasn't stuck. 1486 w.WriteHeader(503) 1487 } 1488 }() 1489 w.WriteHeader(0) 1490 })) 1491 res, err := cst.c.Get(cst.ts.URL) 1492 if err != nil { 1493 t.Fatal(err) 1494 } 1495 if res.StatusCode != 503 { 1496 t.Errorf("Response: %v %q; want 503", res.StatusCode, res.Status) 1497 } 1498 if !<-gotpanic { 1499 t.Error("expected panic in handler") 1500 } 1501 } 1502 1503 // Issue 23010: don't be super strict checking WriteHeader's code if 1504 // it's not even valid to call WriteHeader then anyway. 1505 func TestWriteHeaderNoCodeCheck(t *testing.T) { 1506 run(t, func(t *testing.T, mode testMode) { 1507 testWriteHeaderAfterWrite(t, mode, false) 1508 }) 1509 } 1510 func TestWriteHeaderNoCodeCheck_h1hijack(t *testing.T) { 1511 testWriteHeaderAfterWrite(t, http1Mode, true) 1512 } 1513 func testWriteHeaderAfterWrite(t *testing.T, mode testMode, hijack bool) { 1514 var errorLog lockedBytesBuffer 1515 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { 1516 if hijack { 1517 conn, _, _ := w.(Hijacker).Hijack() 1518 defer conn.Close() 1519 conn.Write([]byte("HTTP/1.1 200 OK\r\nContent-Length: 6\r\n\r\nfoo")) 1520 w.WriteHeader(0) // verify this doesn't panic if there's already output; Issue 23010 1521 conn.Write([]byte("bar")) 1522 return 1523 } 1524 io.WriteString(w, "foo") 1525 w.(Flusher).Flush() 1526 w.WriteHeader(0) // verify this doesn't panic if there's already output; Issue 23010 1527 io.WriteString(w, "bar") 1528 }), func(ts *httptest.Server) { 1529 ts.Config.ErrorLog = log.New(&errorLog, "", 0) 1530 }) 1531 res, err := cst.c.Get(cst.ts.URL) 1532 if err != nil { 1533 t.Fatal(err) 1534 } 1535 defer res.Body.Close() 1536 body, err := io.ReadAll(res.Body) 1537 if err != nil { 1538 t.Fatal(err) 1539 } 1540 if got, want := string(body), "foobar"; got != want { 1541 t.Errorf("got = %q; want %q", got, want) 1542 } 1543 1544 // Also check the stderr output: 1545 if mode == http2Mode { 1546 // TODO: also emit this log message for HTTP/2? 1547 // We historically haven't, so don't check. 1548 return 1549 } 1550 gotLog := strings.TrimSpace(errorLog.String()) 1551 wantLog := "http: superfluous response.WriteHeader call from net/http_test.testWriteHeaderAfterWrite.func1 (clientserver_test.go:" 1552 if hijack { 1553 wantLog = "http: response.WriteHeader on hijacked connection from net/http_test.testWriteHeaderAfterWrite.func1 (clientserver_test.go:" 1554 } 1555 if !strings.HasPrefix(gotLog, wantLog) { 1556 t.Errorf("stderr output = %q; want %q", gotLog, wantLog) 1557 } 1558 } 1559 1560 func TestBidiStreamReverseProxy(t *testing.T) { 1561 run(t, testBidiStreamReverseProxy, []testMode{http2Mode}) 1562 } 1563 func testBidiStreamReverseProxy(t *testing.T, mode testMode) { 1564 backend := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { 1565 if _, err := io.Copy(w, r.Body); err != nil { 1566 log.Printf("bidi backend copy: %v", err) 1567 } 1568 })) 1569 1570 backURL, err := url.Parse(backend.ts.URL) 1571 if err != nil { 1572 t.Fatal(err) 1573 } 1574 rp := httputil.NewSingleHostReverseProxy(backURL) 1575 rp.Transport = backend.tr 1576 proxy := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { 1577 rp.ServeHTTP(w, r) 1578 })) 1579 1580 bodyRes := make(chan any, 1) // error or hash.Hash 1581 pr, pw := io.Pipe() 1582 req, _ := NewRequest("PUT", proxy.ts.URL, pr) 1583 const size = 4 << 20 1584 go func() { 1585 h := sha1.New() 1586 _, err := io.CopyN(io.MultiWriter(h, pw), rand.Reader, size) 1587 go pw.Close() 1588 if err != nil { 1589 bodyRes <- err 1590 } else { 1591 bodyRes <- h 1592 } 1593 }() 1594 res, err := backend.c.Do(req) 1595 if err != nil { 1596 t.Fatal(err) 1597 } 1598 defer res.Body.Close() 1599 hgot := sha1.New() 1600 n, err := io.Copy(hgot, res.Body) 1601 if err != nil { 1602 t.Fatal(err) 1603 } 1604 if n != size { 1605 t.Fatalf("got %d bytes; want %d", n, size) 1606 } 1607 select { 1608 case v := <-bodyRes: 1609 switch v := v.(type) { 1610 default: 1611 t.Fatalf("body copy: %v", err) 1612 case hash.Hash: 1613 if !bytes.Equal(v.Sum(nil), hgot.Sum(nil)) { 1614 t.Errorf("written bytes didn't match received bytes") 1615 } 1616 } 1617 case <-time.After(10 * time.Second): 1618 t.Fatal("timeout") 1619 } 1620 1621 } 1622 1623 // Always use HTTP/1.1 for WebSocket upgrades. 1624 func TestH12_WebSocketUpgrade(t *testing.T) { 1625 h12Compare{ 1626 Handler: func(w ResponseWriter, r *Request) { 1627 h := w.Header() 1628 h.Set("Foo", "bar") 1629 }, 1630 ReqFunc: func(c *Client, url string) (*Response, error) { 1631 req, _ := NewRequest("GET", url, nil) 1632 req.Header.Set("Connection", "Upgrade") 1633 req.Header.Set("Upgrade", "WebSocket") 1634 return c.Do(req) 1635 }, 1636 EarlyCheckResponse: func(proto string, res *Response) { 1637 if res.Proto != "HTTP/1.1" { 1638 t.Errorf("%s: expected HTTP/1.1, got %q", proto, res.Proto) 1639 } 1640 res.Proto = "HTTP/IGNORE" // skip later checks that Proto must be 1.1 vs 2.0 1641 }, 1642 }.run(t) 1643 } 1644 1645 func TestIdentityTransferEncoding(t *testing.T) { run(t, testIdentityTransferEncoding) } 1646 func testIdentityTransferEncoding(t *testing.T, mode testMode) { 1647 const body = "body" 1648 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { 1649 gotBody, _ := io.ReadAll(r.Body) 1650 if got, want := string(gotBody), body; got != want { 1651 t.Errorf("got request body = %q; want %q", got, want) 1652 } 1653 w.Header().Set("Transfer-Encoding", "identity") 1654 w.WriteHeader(StatusOK) 1655 w.(Flusher).Flush() 1656 io.WriteString(w, body) 1657 })) 1658 req, _ := NewRequest("GET", cst.ts.URL, strings.NewReader(body)) 1659 res, err := cst.c.Do(req) 1660 if err != nil { 1661 t.Fatal(err) 1662 } 1663 defer res.Body.Close() 1664 gotBody, err := io.ReadAll(res.Body) 1665 if err != nil { 1666 t.Fatal(err) 1667 } 1668 if got, want := string(gotBody), body; got != want { 1669 t.Errorf("got response body = %q; want %q", got, want) 1670 } 1671 } 1672 1673 func TestEarlyHintsRequest(t *testing.T) { run(t, testEarlyHintsRequest) } 1674 func testEarlyHintsRequest(t *testing.T, mode testMode) { 1675 var wg sync.WaitGroup 1676 wg.Add(1) 1677 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { 1678 h := w.Header() 1679 1680 h.Add("Content-Length", "123") // must be ignored 1681 h.Add("Link", "</style.css>; rel=preload; as=style") 1682 h.Add("Link", "</script.js>; rel=preload; as=script") 1683 w.WriteHeader(StatusEarlyHints) 1684 1685 wg.Wait() 1686 1687 h.Add("Link", "</foo.js>; rel=preload; as=script") 1688 w.WriteHeader(StatusEarlyHints) 1689 1690 w.Write([]byte("Hello")) 1691 })) 1692 1693 checkLinkHeaders := func(t *testing.T, expected, got []string) { 1694 t.Helper() 1695 1696 if len(expected) != len(got) { 1697 t.Errorf("got %d expected %d", len(got), len(expected)) 1698 } 1699 1700 for i := range expected { 1701 if expected[i] != got[i] { 1702 t.Errorf("got %q expected %q", got[i], expected[i]) 1703 } 1704 } 1705 } 1706 1707 checkExcludedHeaders := func(t *testing.T, header textproto.MIMEHeader) { 1708 t.Helper() 1709 1710 for _, h := range []string{"Content-Length", "Transfer-Encoding"} { 1711 if v, ok := header[h]; ok { 1712 t.Errorf("%s is %q; must not be sent", h, v) 1713 } 1714 } 1715 } 1716 1717 var respCounter uint8 1718 trace := &httptrace.ClientTrace{ 1719 Got1xxResponse: func(code int, header textproto.MIMEHeader) error { 1720 switch respCounter { 1721 case 0: 1722 checkLinkHeaders(t, []string{"</style.css>; rel=preload; as=style", "</script.js>; rel=preload; as=script"}, header["Link"]) 1723 checkExcludedHeaders(t, header) 1724 1725 wg.Done() 1726 case 1: 1727 checkLinkHeaders(t, []string{"</style.css>; rel=preload; as=style", "</script.js>; rel=preload; as=script", "</foo.js>; rel=preload; as=script"}, header["Link"]) 1728 checkExcludedHeaders(t, header) 1729 1730 default: 1731 t.Error("Unexpected 1xx response") 1732 } 1733 1734 respCounter++ 1735 1736 return nil 1737 }, 1738 } 1739 req, _ := NewRequestWithContext(httptrace.WithClientTrace(context.Background(), trace), "GET", cst.ts.URL, nil) 1740 1741 res, err := cst.c.Do(req) 1742 if err != nil { 1743 t.Fatal(err) 1744 } 1745 defer res.Body.Close() 1746 1747 checkLinkHeaders(t, []string{"</style.css>; rel=preload; as=style", "</script.js>; rel=preload; as=script", "</foo.js>; rel=preload; as=script"}, res.Header["Link"]) 1748 if cl := res.Header.Get("Content-Length"); cl != "123" { 1749 t.Errorf("Content-Length is %q; want 123", cl) 1750 } 1751 1752 body, _ := io.ReadAll(res.Body) 1753 if string(body) != "Hello" { 1754 t.Errorf("Read body %q; want Hello", body) 1755 } 1756 }