github.com/twelsh-aw/go/src@v0.0.0-20230516233729-a56fe86a7c81/net/http/clientserver_test.go (about) 1 // Copyright 2015 The Go Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 // Tests that use both the client & server, in both HTTP/1 and HTTP/2 mode. 6 7 package http_test 8 9 import ( 10 "bytes" 11 "compress/gzip" 12 "context" 13 "crypto/rand" 14 "crypto/sha1" 15 "crypto/tls" 16 "fmt" 17 "hash" 18 "io" 19 "log" 20 "net" 21 . "net/http" 22 "net/http/httptest" 23 "net/http/httptrace" 24 "net/http/httputil" 25 "net/textproto" 26 "net/url" 27 "os" 28 "reflect" 29 "runtime" 30 "sort" 31 "strings" 32 "sync" 33 "sync/atomic" 34 "testing" 35 "time" 36 ) 37 38 type testMode string 39 40 const ( 41 http1Mode = testMode("h1") // HTTP/1.1 42 https1Mode = testMode("https1") // HTTPS/1.1 43 http2Mode = testMode("h2") // HTTP/2 44 ) 45 46 type testNotParallelOpt struct{} 47 48 var ( 49 testNotParallel = testNotParallelOpt{} 50 ) 51 52 type TBRun[T any] interface { 53 testing.TB 54 Run(string, func(T)) bool 55 } 56 57 // run runs a client/server test in a variety of test configurations. 58 // 59 // Tests execute in HTTP/1.1 and HTTP/2 modes by default. 60 // To run in a different set of configurations, pass a []testMode option. 61 // 62 // Tests call t.Parallel() by default. 63 // To disable parallel execution, pass the testNotParallel option. 64 func run[T TBRun[T]](t T, f func(t T, mode testMode), opts ...any) { 65 t.Helper() 66 modes := []testMode{http1Mode, http2Mode} 67 parallel := true 68 for _, opt := range opts { 69 switch opt := opt.(type) { 70 case []testMode: 71 modes = opt 72 case testNotParallelOpt: 73 parallel = false 74 default: 75 t.Fatalf("unknown option type %T", opt) 76 } 77 } 78 if t, ok := any(t).(*testing.T); ok && parallel { 79 setParallel(t) 80 } 81 for _, mode := range modes { 82 t.Run(string(mode), func(t T) { 83 t.Helper() 84 if t, ok := any(t).(*testing.T); ok && parallel { 85 setParallel(t) 86 } 87 t.Cleanup(func() { 88 afterTest(t) 89 }) 90 f(t, mode) 91 }) 92 } 93 } 94 95 type clientServerTest struct { 96 t testing.TB 97 h2 bool 98 h Handler 99 ts *httptest.Server 100 tr *Transport 101 c *Client 102 } 103 104 func (t *clientServerTest) close() { 105 t.tr.CloseIdleConnections() 106 t.ts.Close() 107 } 108 109 func (t *clientServerTest) getURL(u string) string { 110 res, err := t.c.Get(u) 111 if err != nil { 112 t.t.Fatal(err) 113 } 114 defer res.Body.Close() 115 slurp, err := io.ReadAll(res.Body) 116 if err != nil { 117 t.t.Fatal(err) 118 } 119 return string(slurp) 120 } 121 122 func (t *clientServerTest) scheme() string { 123 if t.h2 { 124 return "https" 125 } 126 return "http" 127 } 128 129 var optQuietLog = func(ts *httptest.Server) { 130 ts.Config.ErrorLog = quietLog 131 } 132 133 func optWithServerLog(lg *log.Logger) func(*httptest.Server) { 134 return func(ts *httptest.Server) { 135 ts.Config.ErrorLog = lg 136 } 137 } 138 139 // newClientServerTest creates and starts an httptest.Server. 140 // 141 // The mode parameter selects the implementation to test: 142 // HTTP/1, HTTP/2, etc. Tests using newClientServerTest should use 143 // the 'run' function, which will start a subtests for each tested mode. 144 // 145 // The vararg opts parameter can include functions to configure the 146 // test server or transport. 147 // 148 // func(*httptest.Server) // run before starting the server 149 // func(*http.Transport) 150 func newClientServerTest(t testing.TB, mode testMode, h Handler, opts ...any) *clientServerTest { 151 if mode == http2Mode { 152 CondSkipHTTP2(t) 153 } 154 cst := &clientServerTest{ 155 t: t, 156 h2: mode == http2Mode, 157 h: h, 158 } 159 cst.ts = httptest.NewUnstartedServer(h) 160 161 var transportFuncs []func(*Transport) 162 for _, opt := range opts { 163 switch opt := opt.(type) { 164 case func(*Transport): 165 transportFuncs = append(transportFuncs, opt) 166 case func(*httptest.Server): 167 opt(cst.ts) 168 default: 169 t.Fatalf("unhandled option type %T", opt) 170 } 171 } 172 173 if cst.ts.Config.ErrorLog == nil { 174 cst.ts.Config.ErrorLog = log.New(testLogWriter{t}, "", 0) 175 } 176 177 switch mode { 178 case http1Mode: 179 cst.ts.Start() 180 case https1Mode: 181 cst.ts.StartTLS() 182 case http2Mode: 183 ExportHttp2ConfigureServer(cst.ts.Config, nil) 184 cst.ts.TLS = cst.ts.Config.TLSConfig 185 cst.ts.StartTLS() 186 default: 187 t.Fatalf("unknown test mode %v", mode) 188 } 189 cst.c = cst.ts.Client() 190 cst.tr = cst.c.Transport.(*Transport) 191 if mode == http2Mode { 192 if err := ExportHttp2ConfigureTransport(cst.tr); err != nil { 193 t.Fatal(err) 194 } 195 } 196 for _, f := range transportFuncs { 197 f(cst.tr) 198 } 199 t.Cleanup(func() { 200 cst.close() 201 }) 202 return cst 203 } 204 205 type testLogWriter struct { 206 t testing.TB 207 } 208 209 func (w testLogWriter) Write(b []byte) (int, error) { 210 w.t.Logf("server log: %v", strings.TrimSpace(string(b))) 211 return len(b), nil 212 } 213 214 // Testing the newClientServerTest helper itself. 215 func TestNewClientServerTest(t *testing.T) { 216 run(t, testNewClientServerTest, []testMode{http1Mode, https1Mode, http2Mode}) 217 } 218 func testNewClientServerTest(t *testing.T, mode testMode) { 219 var got struct { 220 sync.Mutex 221 proto string 222 hasTLS bool 223 } 224 h := HandlerFunc(func(w ResponseWriter, r *Request) { 225 got.Lock() 226 defer got.Unlock() 227 got.proto = r.Proto 228 got.hasTLS = r.TLS != nil 229 }) 230 cst := newClientServerTest(t, mode, h) 231 if _, err := cst.c.Head(cst.ts.URL); err != nil { 232 t.Fatal(err) 233 } 234 var wantProto string 235 var wantTLS bool 236 switch mode { 237 case http1Mode: 238 wantProto = "HTTP/1.1" 239 wantTLS = false 240 case https1Mode: 241 wantProto = "HTTP/1.1" 242 wantTLS = true 243 case http2Mode: 244 wantProto = "HTTP/2.0" 245 wantTLS = true 246 } 247 if got.proto != wantProto { 248 t.Errorf("req.Proto = %q, want %q", got.proto, wantProto) 249 } 250 if got.hasTLS != wantTLS { 251 t.Errorf("req.TLS set: %v, want %v", got.hasTLS, wantTLS) 252 } 253 } 254 255 func TestChunkedResponseHeaders(t *testing.T) { run(t, testChunkedResponseHeaders) } 256 func testChunkedResponseHeaders(t *testing.T, mode testMode) { 257 log.SetOutput(io.Discard) // is noisy otherwise 258 defer log.SetOutput(os.Stderr) 259 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { 260 w.Header().Set("Content-Length", "intentional gibberish") // we check that this is deleted 261 w.(Flusher).Flush() 262 fmt.Fprintf(w, "I am a chunked response.") 263 })) 264 265 res, err := cst.c.Get(cst.ts.URL) 266 if err != nil { 267 t.Fatalf("Get error: %v", err) 268 } 269 defer res.Body.Close() 270 if g, e := res.ContentLength, int64(-1); g != e { 271 t.Errorf("expected ContentLength of %d; got %d", e, g) 272 } 273 wantTE := []string{"chunked"} 274 if mode == http2Mode { 275 wantTE = nil 276 } 277 if !reflect.DeepEqual(res.TransferEncoding, wantTE) { 278 t.Errorf("TransferEncoding = %v; want %v", res.TransferEncoding, wantTE) 279 } 280 if got, haveCL := res.Header["Content-Length"]; haveCL { 281 t.Errorf("Unexpected Content-Length: %q", got) 282 } 283 } 284 285 type reqFunc func(c *Client, url string) (*Response, error) 286 287 // h12Compare is a test that compares HTTP/1 and HTTP/2 behavior 288 // against each other. 289 type h12Compare struct { 290 Handler func(ResponseWriter, *Request) // required 291 ReqFunc reqFunc // optional 292 CheckResponse func(proto string, res *Response) // optional 293 EarlyCheckResponse func(proto string, res *Response) // optional; pre-normalize 294 Opts []any 295 } 296 297 func (tt h12Compare) reqFunc() reqFunc { 298 if tt.ReqFunc == nil { 299 return (*Client).Get 300 } 301 return tt.ReqFunc 302 } 303 304 func (tt h12Compare) run(t *testing.T) { 305 setParallel(t) 306 cst1 := newClientServerTest(t, http1Mode, HandlerFunc(tt.Handler), tt.Opts...) 307 defer cst1.close() 308 cst2 := newClientServerTest(t, http2Mode, HandlerFunc(tt.Handler), tt.Opts...) 309 defer cst2.close() 310 311 res1, err := tt.reqFunc()(cst1.c, cst1.ts.URL) 312 if err != nil { 313 t.Errorf("HTTP/1 request: %v", err) 314 return 315 } 316 res2, err := tt.reqFunc()(cst2.c, cst2.ts.URL) 317 if err != nil { 318 t.Errorf("HTTP/2 request: %v", err) 319 return 320 } 321 322 if fn := tt.EarlyCheckResponse; fn != nil { 323 fn("HTTP/1.1", res1) 324 fn("HTTP/2.0", res2) 325 } 326 327 tt.normalizeRes(t, res1, "HTTP/1.1") 328 tt.normalizeRes(t, res2, "HTTP/2.0") 329 res1body, res2body := res1.Body, res2.Body 330 331 eres1 := mostlyCopy(res1) 332 eres2 := mostlyCopy(res2) 333 if !reflect.DeepEqual(eres1, eres2) { 334 t.Errorf("Response headers to handler differed:\nhttp/1 (%v):\n\t%#v\nhttp/2 (%v):\n\t%#v", 335 cst1.ts.URL, eres1, cst2.ts.URL, eres2) 336 } 337 if !reflect.DeepEqual(res1body, res2body) { 338 t.Errorf("Response bodies to handler differed.\nhttp1: %v\nhttp2: %v\n", res1body, res2body) 339 } 340 if fn := tt.CheckResponse; fn != nil { 341 res1.Body, res2.Body = res1body, res2body 342 fn("HTTP/1.1", res1) 343 fn("HTTP/2.0", res2) 344 } 345 } 346 347 func mostlyCopy(r *Response) *Response { 348 c := *r 349 c.Body = nil 350 c.TransferEncoding = nil 351 c.TLS = nil 352 c.Request = nil 353 return &c 354 } 355 356 type slurpResult struct { 357 io.ReadCloser 358 body []byte 359 err error 360 } 361 362 func (sr slurpResult) String() string { return fmt.Sprintf("body %q; err %v", sr.body, sr.err) } 363 364 func (tt h12Compare) normalizeRes(t *testing.T, res *Response, wantProto string) { 365 if res.Proto == wantProto || res.Proto == "HTTP/IGNORE" { 366 res.Proto, res.ProtoMajor, res.ProtoMinor = "", 0, 0 367 } else { 368 t.Errorf("got %q response; want %q", res.Proto, wantProto) 369 } 370 slurp, err := io.ReadAll(res.Body) 371 372 res.Body.Close() 373 res.Body = slurpResult{ 374 ReadCloser: io.NopCloser(bytes.NewReader(slurp)), 375 body: slurp, 376 err: err, 377 } 378 for i, v := range res.Header["Date"] { 379 res.Header["Date"][i] = strings.Repeat("x", len(v)) 380 } 381 if res.Request == nil { 382 t.Errorf("for %s, no request", wantProto) 383 } 384 if (res.TLS != nil) != (wantProto == "HTTP/2.0") { 385 t.Errorf("TLS set = %v; want %v", res.TLS != nil, res.TLS == nil) 386 } 387 } 388 389 // Issue 13532 390 func TestH12_HeadContentLengthNoBody(t *testing.T) { 391 h12Compare{ 392 ReqFunc: (*Client).Head, 393 Handler: func(w ResponseWriter, r *Request) { 394 }, 395 }.run(t) 396 } 397 398 func TestH12_HeadContentLengthSmallBody(t *testing.T) { 399 h12Compare{ 400 ReqFunc: (*Client).Head, 401 Handler: func(w ResponseWriter, r *Request) { 402 io.WriteString(w, "small") 403 }, 404 }.run(t) 405 } 406 407 func TestH12_HeadContentLengthLargeBody(t *testing.T) { 408 h12Compare{ 409 ReqFunc: (*Client).Head, 410 Handler: func(w ResponseWriter, r *Request) { 411 chunk := strings.Repeat("x", 512<<10) 412 for i := 0; i < 10; i++ { 413 io.WriteString(w, chunk) 414 } 415 }, 416 }.run(t) 417 } 418 419 func TestH12_200NoBody(t *testing.T) { 420 h12Compare{Handler: func(w ResponseWriter, r *Request) {}}.run(t) 421 } 422 423 func TestH2_204NoBody(t *testing.T) { testH12_noBody(t, 204) } 424 func TestH2_304NoBody(t *testing.T) { testH12_noBody(t, 304) } 425 func TestH2_404NoBody(t *testing.T) { testH12_noBody(t, 404) } 426 427 func testH12_noBody(t *testing.T, status int) { 428 h12Compare{Handler: func(w ResponseWriter, r *Request) { 429 w.WriteHeader(status) 430 }}.run(t) 431 } 432 433 func TestH12_SmallBody(t *testing.T) { 434 h12Compare{Handler: func(w ResponseWriter, r *Request) { 435 io.WriteString(w, "small body") 436 }}.run(t) 437 } 438 439 func TestH12_ExplicitContentLength(t *testing.T) { 440 h12Compare{Handler: func(w ResponseWriter, r *Request) { 441 w.Header().Set("Content-Length", "3") 442 io.WriteString(w, "foo") 443 }}.run(t) 444 } 445 446 func TestH12_FlushBeforeBody(t *testing.T) { 447 h12Compare{Handler: func(w ResponseWriter, r *Request) { 448 w.(Flusher).Flush() 449 io.WriteString(w, "foo") 450 }}.run(t) 451 } 452 453 func TestH12_FlushMidBody(t *testing.T) { 454 h12Compare{Handler: func(w ResponseWriter, r *Request) { 455 io.WriteString(w, "foo") 456 w.(Flusher).Flush() 457 io.WriteString(w, "bar") 458 }}.run(t) 459 } 460 461 func TestH12_Head_ExplicitLen(t *testing.T) { 462 h12Compare{ 463 ReqFunc: (*Client).Head, 464 Handler: func(w ResponseWriter, r *Request) { 465 if r.Method != "HEAD" { 466 t.Errorf("unexpected method %q", r.Method) 467 } 468 w.Header().Set("Content-Length", "1235") 469 }, 470 }.run(t) 471 } 472 473 func TestH12_Head_ImplicitLen(t *testing.T) { 474 h12Compare{ 475 ReqFunc: (*Client).Head, 476 Handler: func(w ResponseWriter, r *Request) { 477 if r.Method != "HEAD" { 478 t.Errorf("unexpected method %q", r.Method) 479 } 480 io.WriteString(w, "foo") 481 }, 482 }.run(t) 483 } 484 485 func TestH12_HandlerWritesTooLittle(t *testing.T) { 486 h12Compare{ 487 Handler: func(w ResponseWriter, r *Request) { 488 w.Header().Set("Content-Length", "3") 489 io.WriteString(w, "12") // one byte short 490 }, 491 CheckResponse: func(proto string, res *Response) { 492 sr, ok := res.Body.(slurpResult) 493 if !ok { 494 t.Errorf("%s body is %T; want slurpResult", proto, res.Body) 495 return 496 } 497 if sr.err != io.ErrUnexpectedEOF { 498 t.Errorf("%s read error = %v; want io.ErrUnexpectedEOF", proto, sr.err) 499 } 500 if string(sr.body) != "12" { 501 t.Errorf("%s body = %q; want %q", proto, sr.body, "12") 502 } 503 }, 504 }.run(t) 505 } 506 507 // Tests that the HTTP/1 and HTTP/2 servers prevent handlers from 508 // writing more than they declared. This test does not test whether 509 // the transport deals with too much data, though, since the server 510 // doesn't make it possible to send bogus data. For those tests, see 511 // transport_test.go (for HTTP/1) or x/net/http2/transport_test.go 512 // (for HTTP/2). 513 func TestH12_HandlerWritesTooMuch(t *testing.T) { 514 h12Compare{ 515 Handler: func(w ResponseWriter, r *Request) { 516 w.Header().Set("Content-Length", "3") 517 w.(Flusher).Flush() 518 io.WriteString(w, "123") 519 w.(Flusher).Flush() 520 n, err := io.WriteString(w, "x") // too many 521 if n > 0 || err == nil { 522 t.Errorf("for proto %q, final write = %v, %v; want 0, some error", r.Proto, n, err) 523 } 524 }, 525 }.run(t) 526 } 527 528 // Verify that both our HTTP/1 and HTTP/2 request and auto-decompress gzip. 529 // Some hosts send gzip even if you don't ask for it; see golang.org/issue/13298 530 func TestH12_AutoGzip(t *testing.T) { 531 h12Compare{ 532 Handler: func(w ResponseWriter, r *Request) { 533 if ae := r.Header.Get("Accept-Encoding"); ae != "gzip" { 534 t.Errorf("%s Accept-Encoding = %q; want gzip", r.Proto, ae) 535 } 536 w.Header().Set("Content-Encoding", "gzip") 537 gz := gzip.NewWriter(w) 538 io.WriteString(gz, "I am some gzipped content. Go go go go go go go go go go go go should compress well.") 539 gz.Close() 540 }, 541 }.run(t) 542 } 543 544 func TestH12_AutoGzip_Disabled(t *testing.T) { 545 h12Compare{ 546 Opts: []any{ 547 func(tr *Transport) { tr.DisableCompression = true }, 548 }, 549 Handler: func(w ResponseWriter, r *Request) { 550 fmt.Fprintf(w, "%q", r.Header["Accept-Encoding"]) 551 if ae := r.Header.Get("Accept-Encoding"); ae != "" { 552 t.Errorf("%s Accept-Encoding = %q; want empty", r.Proto, ae) 553 } 554 }, 555 }.run(t) 556 } 557 558 // Test304Responses verifies that 304s don't declare that they're 559 // chunking in their response headers and aren't allowed to produce 560 // output. 561 func Test304Responses(t *testing.T) { run(t, test304Responses) } 562 func test304Responses(t *testing.T, mode testMode) { 563 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { 564 w.WriteHeader(StatusNotModified) 565 _, err := w.Write([]byte("illegal body")) 566 if err != ErrBodyNotAllowed { 567 t.Errorf("on Write, expected ErrBodyNotAllowed, got %v", err) 568 } 569 })) 570 defer cst.close() 571 res, err := cst.c.Get(cst.ts.URL) 572 if err != nil { 573 t.Fatal(err) 574 } 575 if len(res.TransferEncoding) > 0 { 576 t.Errorf("expected no TransferEncoding; got %v", res.TransferEncoding) 577 } 578 body, err := io.ReadAll(res.Body) 579 if err != nil { 580 t.Error(err) 581 } 582 if len(body) > 0 { 583 t.Errorf("got unexpected body %q", string(body)) 584 } 585 } 586 587 func TestH12_ServerEmptyContentLength(t *testing.T) { 588 h12Compare{ 589 Handler: func(w ResponseWriter, r *Request) { 590 w.Header()["Content-Type"] = []string{""} 591 io.WriteString(w, "<html><body>hi</body></html>") 592 }, 593 }.run(t) 594 } 595 596 func TestH12_RequestContentLength_Known_NonZero(t *testing.T) { 597 h12requestContentLength(t, func() io.Reader { return strings.NewReader("FOUR") }, 4) 598 } 599 600 func TestH12_RequestContentLength_Known_Zero(t *testing.T) { 601 h12requestContentLength(t, func() io.Reader { return nil }, 0) 602 } 603 604 func TestH12_RequestContentLength_Unknown(t *testing.T) { 605 h12requestContentLength(t, func() io.Reader { return struct{ io.Reader }{strings.NewReader("Stuff")} }, -1) 606 } 607 608 func h12requestContentLength(t *testing.T, bodyfn func() io.Reader, wantLen int64) { 609 h12Compare{ 610 Handler: func(w ResponseWriter, r *Request) { 611 w.Header().Set("Got-Length", fmt.Sprint(r.ContentLength)) 612 fmt.Fprintf(w, "Req.ContentLength=%v", r.ContentLength) 613 }, 614 ReqFunc: func(c *Client, url string) (*Response, error) { 615 return c.Post(url, "text/plain", bodyfn()) 616 }, 617 CheckResponse: func(proto string, res *Response) { 618 if got, want := res.Header.Get("Got-Length"), fmt.Sprint(wantLen); got != want { 619 t.Errorf("Proto %q got length %q; want %q", proto, got, want) 620 } 621 }, 622 }.run(t) 623 } 624 625 // Tests that closing the Request.Cancel channel also while still 626 // reading the response body. Issue 13159. 627 func TestCancelRequestMidBody(t *testing.T) { run(t, testCancelRequestMidBody) } 628 func testCancelRequestMidBody(t *testing.T, mode testMode) { 629 unblock := make(chan bool) 630 didFlush := make(chan bool, 1) 631 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { 632 io.WriteString(w, "Hello") 633 w.(Flusher).Flush() 634 didFlush <- true 635 <-unblock 636 io.WriteString(w, ", world.") 637 })) 638 defer close(unblock) 639 640 req, _ := NewRequest("GET", cst.ts.URL, nil) 641 cancel := make(chan struct{}) 642 req.Cancel = cancel 643 644 res, err := cst.c.Do(req) 645 if err != nil { 646 t.Fatal(err) 647 } 648 defer res.Body.Close() 649 <-didFlush 650 651 // Read a bit before we cancel. (Issue 13626) 652 // We should have "Hello" at least sitting there. 653 firstRead := make([]byte, 10) 654 n, err := res.Body.Read(firstRead) 655 if err != nil { 656 t.Fatal(err) 657 } 658 firstRead = firstRead[:n] 659 660 close(cancel) 661 662 rest, err := io.ReadAll(res.Body) 663 all := string(firstRead) + string(rest) 664 if all != "Hello" { 665 t.Errorf("Read %q (%q + %q); want Hello", all, firstRead, rest) 666 } 667 if err != ExportErrRequestCanceled { 668 t.Errorf("ReadAll error = %v; want %v", err, ExportErrRequestCanceled) 669 } 670 } 671 672 // Tests that clients can send trailers to a server and that the server can read them. 673 func TestTrailersClientToServer(t *testing.T) { run(t, testTrailersClientToServer) } 674 func testTrailersClientToServer(t *testing.T, mode testMode) { 675 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { 676 var decl []string 677 for k := range r.Trailer { 678 decl = append(decl, k) 679 } 680 sort.Strings(decl) 681 682 slurp, err := io.ReadAll(r.Body) 683 if err != nil { 684 t.Errorf("Server reading request body: %v", err) 685 } 686 if string(slurp) != "foo" { 687 t.Errorf("Server read request body %q; want foo", slurp) 688 } 689 if r.Trailer == nil { 690 io.WriteString(w, "nil Trailer") 691 } else { 692 fmt.Fprintf(w, "decl: %v, vals: %s, %s", 693 decl, 694 r.Trailer.Get("Client-Trailer-A"), 695 r.Trailer.Get("Client-Trailer-B")) 696 } 697 })) 698 699 var req *Request 700 req, _ = NewRequest("POST", cst.ts.URL, io.MultiReader( 701 eofReaderFunc(func() { 702 req.Trailer["Client-Trailer-A"] = []string{"valuea"} 703 }), 704 strings.NewReader("foo"), 705 eofReaderFunc(func() { 706 req.Trailer["Client-Trailer-B"] = []string{"valueb"} 707 }), 708 )) 709 req.Trailer = Header{ 710 "Client-Trailer-A": nil, // to be set later 711 "Client-Trailer-B": nil, // to be set later 712 } 713 req.ContentLength = -1 714 res, err := cst.c.Do(req) 715 if err != nil { 716 t.Fatal(err) 717 } 718 if err := wantBody(res, err, "decl: [Client-Trailer-A Client-Trailer-B], vals: valuea, valueb"); err != nil { 719 t.Error(err) 720 } 721 } 722 723 // Tests that servers send trailers to a client and that the client can read them. 724 func TestTrailersServerToClient(t *testing.T) { 725 run(t, func(t *testing.T, mode testMode) { 726 testTrailersServerToClient(t, mode, false) 727 }) 728 } 729 func TestTrailersServerToClientFlush(t *testing.T) { 730 run(t, func(t *testing.T, mode testMode) { 731 testTrailersServerToClient(t, mode, true) 732 }) 733 } 734 735 func testTrailersServerToClient(t *testing.T, mode testMode, flush bool) { 736 const body = "Some body" 737 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { 738 w.Header().Set("Trailer", "Server-Trailer-A, Server-Trailer-B") 739 w.Header().Add("Trailer", "Server-Trailer-C") 740 741 io.WriteString(w, body) 742 if flush { 743 w.(Flusher).Flush() 744 } 745 746 // How handlers set Trailers: declare it ahead of time 747 // with the Trailer header, and then mutate the 748 // Header() of those values later, after the response 749 // has been written (we wrote to w above). 750 w.Header().Set("Server-Trailer-A", "valuea") 751 w.Header().Set("Server-Trailer-C", "valuec") // skipping B 752 w.Header().Set("Server-Trailer-NotDeclared", "should be omitted") 753 })) 754 755 res, err := cst.c.Get(cst.ts.URL) 756 if err != nil { 757 t.Fatal(err) 758 } 759 760 wantHeader := Header{ 761 "Content-Type": {"text/plain; charset=utf-8"}, 762 } 763 wantLen := -1 764 if mode == http2Mode && !flush { 765 // In HTTP/1.1, any use of trailers forces HTTP/1.1 766 // chunking and a flush at the first write. That's 767 // unnecessary with HTTP/2's framing, so the server 768 // is able to calculate the length while still sending 769 // trailers afterwards. 770 wantLen = len(body) 771 wantHeader["Content-Length"] = []string{fmt.Sprint(wantLen)} 772 } 773 if res.ContentLength != int64(wantLen) { 774 t.Errorf("ContentLength = %v; want %v", res.ContentLength, wantLen) 775 } 776 777 delete(res.Header, "Date") // irrelevant for test 778 if !reflect.DeepEqual(res.Header, wantHeader) { 779 t.Errorf("Header = %v; want %v", res.Header, wantHeader) 780 } 781 782 if got, want := res.Trailer, (Header{ 783 "Server-Trailer-A": nil, 784 "Server-Trailer-B": nil, 785 "Server-Trailer-C": nil, 786 }); !reflect.DeepEqual(got, want) { 787 t.Errorf("Trailer before body read = %v; want %v", got, want) 788 } 789 790 if err := wantBody(res, nil, body); err != nil { 791 t.Fatal(err) 792 } 793 794 if got, want := res.Trailer, (Header{ 795 "Server-Trailer-A": {"valuea"}, 796 "Server-Trailer-B": nil, 797 "Server-Trailer-C": {"valuec"}, 798 }); !reflect.DeepEqual(got, want) { 799 t.Errorf("Trailer after body read = %v; want %v", got, want) 800 } 801 } 802 803 // Don't allow a Body.Read after Body.Close. Issue 13648. 804 func TestResponseBodyReadAfterClose(t *testing.T) { run(t, testResponseBodyReadAfterClose) } 805 func testResponseBodyReadAfterClose(t *testing.T, mode testMode) { 806 const body = "Some body" 807 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { 808 io.WriteString(w, body) 809 })) 810 res, err := cst.c.Get(cst.ts.URL) 811 if err != nil { 812 t.Fatal(err) 813 } 814 res.Body.Close() 815 data, err := io.ReadAll(res.Body) 816 if len(data) != 0 || err == nil { 817 t.Fatalf("ReadAll returned %q, %v; want error", data, err) 818 } 819 } 820 821 func TestConcurrentReadWriteReqBody(t *testing.T) { run(t, testConcurrentReadWriteReqBody) } 822 func testConcurrentReadWriteReqBody(t *testing.T, mode testMode) { 823 const reqBody = "some request body" 824 const resBody = "some response body" 825 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { 826 var wg sync.WaitGroup 827 wg.Add(2) 828 didRead := make(chan bool, 1) 829 // Read in one goroutine. 830 go func() { 831 defer wg.Done() 832 data, err := io.ReadAll(r.Body) 833 if string(data) != reqBody { 834 t.Errorf("Handler read %q; want %q", data, reqBody) 835 } 836 if err != nil { 837 t.Errorf("Handler Read: %v", err) 838 } 839 didRead <- true 840 }() 841 // Write in another goroutine. 842 go func() { 843 defer wg.Done() 844 if mode != http2Mode { 845 // our HTTP/1 implementation intentionally 846 // doesn't permit writes during read (mostly 847 // due to it being undefined); if that is ever 848 // relaxed, change this. 849 <-didRead 850 } 851 io.WriteString(w, resBody) 852 }() 853 wg.Wait() 854 })) 855 req, _ := NewRequest("POST", cst.ts.URL, strings.NewReader(reqBody)) 856 req.Header.Add("Expect", "100-continue") // just to complicate things 857 res, err := cst.c.Do(req) 858 if err != nil { 859 t.Fatal(err) 860 } 861 data, err := io.ReadAll(res.Body) 862 defer res.Body.Close() 863 if err != nil { 864 t.Fatal(err) 865 } 866 if string(data) != resBody { 867 t.Errorf("read %q; want %q", data, resBody) 868 } 869 } 870 871 func TestConnectRequest(t *testing.T) { run(t, testConnectRequest) } 872 func testConnectRequest(t *testing.T, mode testMode) { 873 gotc := make(chan *Request, 1) 874 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { 875 gotc <- r 876 })) 877 878 u, err := url.Parse(cst.ts.URL) 879 if err != nil { 880 t.Fatal(err) 881 } 882 883 tests := []struct { 884 req *Request 885 want string 886 }{ 887 { 888 req: &Request{ 889 Method: "CONNECT", 890 Header: Header{}, 891 URL: u, 892 }, 893 want: u.Host, 894 }, 895 { 896 req: &Request{ 897 Method: "CONNECT", 898 Header: Header{}, 899 URL: u, 900 Host: "example.com:123", 901 }, 902 want: "example.com:123", 903 }, 904 } 905 906 for i, tt := range tests { 907 res, err := cst.c.Do(tt.req) 908 if err != nil { 909 t.Errorf("%d. RoundTrip = %v", i, err) 910 continue 911 } 912 res.Body.Close() 913 req := <-gotc 914 if req.Method != "CONNECT" { 915 t.Errorf("method = %q; want CONNECT", req.Method) 916 } 917 if req.Host != tt.want { 918 t.Errorf("Host = %q; want %q", req.Host, tt.want) 919 } 920 if req.URL.Host != tt.want { 921 t.Errorf("URL.Host = %q; want %q", req.URL.Host, tt.want) 922 } 923 } 924 } 925 926 func TestTransportUserAgent(t *testing.T) { run(t, testTransportUserAgent) } 927 func testTransportUserAgent(t *testing.T, mode testMode) { 928 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { 929 fmt.Fprintf(w, "%q", r.Header["User-Agent"]) 930 })) 931 932 either := func(a, b string) string { 933 if mode == http2Mode { 934 return b 935 } 936 return a 937 } 938 939 tests := []struct { 940 setup func(*Request) 941 want string 942 }{ 943 { 944 func(r *Request) {}, 945 either(`["Go-http-client/1.1"]`, `["Go-http-client/2.0"]`), 946 }, 947 { 948 func(r *Request) { r.Header.Set("User-Agent", "foo/1.2.3") }, 949 `["foo/1.2.3"]`, 950 }, 951 { 952 func(r *Request) { r.Header["User-Agent"] = []string{"single", "or", "multiple"} }, 953 `["single"]`, 954 }, 955 { 956 func(r *Request) { r.Header.Set("User-Agent", "") }, 957 `[]`, 958 }, 959 { 960 func(r *Request) { r.Header["User-Agent"] = nil }, 961 `[]`, 962 }, 963 } 964 for i, tt := range tests { 965 req, _ := NewRequest("GET", cst.ts.URL, nil) 966 tt.setup(req) 967 res, err := cst.c.Do(req) 968 if err != nil { 969 t.Errorf("%d. RoundTrip = %v", i, err) 970 continue 971 } 972 slurp, err := io.ReadAll(res.Body) 973 res.Body.Close() 974 if err != nil { 975 t.Errorf("%d. read body = %v", i, err) 976 continue 977 } 978 if string(slurp) != tt.want { 979 t.Errorf("%d. body mismatch.\n got: %s\nwant: %s\n", i, slurp, tt.want) 980 } 981 } 982 } 983 984 func TestStarRequestMethod(t *testing.T) { 985 for _, method := range []string{"FOO", "OPTIONS"} { 986 t.Run(method, func(t *testing.T) { 987 run(t, func(t *testing.T, mode testMode) { 988 testStarRequest(t, method, mode) 989 }) 990 }) 991 } 992 } 993 func testStarRequest(t *testing.T, method string, mode testMode) { 994 gotc := make(chan *Request, 1) 995 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { 996 w.Header().Set("foo", "bar") 997 gotc <- r 998 w.(Flusher).Flush() 999 })) 1000 1001 u, err := url.Parse(cst.ts.URL) 1002 if err != nil { 1003 t.Fatal(err) 1004 } 1005 u.Path = "*" 1006 1007 req := &Request{ 1008 Method: method, 1009 Header: Header{}, 1010 URL: u, 1011 } 1012 1013 res, err := cst.c.Do(req) 1014 if err != nil { 1015 t.Fatalf("RoundTrip = %v", err) 1016 } 1017 res.Body.Close() 1018 1019 wantFoo := "bar" 1020 wantLen := int64(-1) 1021 if method == "OPTIONS" { 1022 wantFoo = "" 1023 wantLen = 0 1024 } 1025 if res.StatusCode != 200 { 1026 t.Errorf("status code = %v; want %d", res.Status, 200) 1027 } 1028 if res.ContentLength != wantLen { 1029 t.Errorf("content length = %v; want %d", res.ContentLength, wantLen) 1030 } 1031 if got := res.Header.Get("foo"); got != wantFoo { 1032 t.Errorf("response \"foo\" header = %q; want %q", got, wantFoo) 1033 } 1034 select { 1035 case req = <-gotc: 1036 default: 1037 req = nil 1038 } 1039 if req == nil { 1040 if method != "OPTIONS" { 1041 t.Fatalf("handler never got request") 1042 } 1043 return 1044 } 1045 if req.Method != method { 1046 t.Errorf("method = %q; want %q", req.Method, method) 1047 } 1048 if req.URL.Path != "*" { 1049 t.Errorf("URL.Path = %q; want *", req.URL.Path) 1050 } 1051 if req.RequestURI != "*" { 1052 t.Errorf("RequestURI = %q; want *", req.RequestURI) 1053 } 1054 } 1055 1056 // Issue 13957 1057 func TestTransportDiscardsUnneededConns(t *testing.T) { 1058 run(t, testTransportDiscardsUnneededConns, []testMode{http2Mode}) 1059 } 1060 func testTransportDiscardsUnneededConns(t *testing.T, mode testMode) { 1061 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { 1062 fmt.Fprintf(w, "Hello, %v", r.RemoteAddr) 1063 })) 1064 defer cst.close() 1065 1066 var numOpen, numClose int32 // atomic 1067 1068 tlsConfig := &tls.Config{InsecureSkipVerify: true} 1069 tr := &Transport{ 1070 TLSClientConfig: tlsConfig, 1071 DialTLS: func(_, addr string) (net.Conn, error) { 1072 time.Sleep(10 * time.Millisecond) 1073 rc, err := net.Dial("tcp", addr) 1074 if err != nil { 1075 return nil, err 1076 } 1077 atomic.AddInt32(&numOpen, 1) 1078 c := noteCloseConn{rc, func() { atomic.AddInt32(&numClose, 1) }} 1079 return tls.Client(c, tlsConfig), nil 1080 }, 1081 } 1082 if err := ExportHttp2ConfigureTransport(tr); err != nil { 1083 t.Fatal(err) 1084 } 1085 defer tr.CloseIdleConnections() 1086 1087 c := &Client{Transport: tr} 1088 1089 const N = 10 1090 gotBody := make(chan string, N) 1091 var wg sync.WaitGroup 1092 for i := 0; i < N; i++ { 1093 wg.Add(1) 1094 go func() { 1095 defer wg.Done() 1096 resp, err := c.Get(cst.ts.URL) 1097 if err != nil { 1098 // Try to work around spurious connection reset on loaded system. 1099 // See golang.org/issue/33585 and golang.org/issue/36797. 1100 time.Sleep(10 * time.Millisecond) 1101 resp, err = c.Get(cst.ts.URL) 1102 if err != nil { 1103 t.Errorf("Get: %v", err) 1104 return 1105 } 1106 } 1107 defer resp.Body.Close() 1108 slurp, err := io.ReadAll(resp.Body) 1109 if err != nil { 1110 t.Error(err) 1111 } 1112 gotBody <- string(slurp) 1113 }() 1114 } 1115 wg.Wait() 1116 close(gotBody) 1117 1118 var last string 1119 for got := range gotBody { 1120 if last == "" { 1121 last = got 1122 continue 1123 } 1124 if got != last { 1125 t.Errorf("Response body changed: %q -> %q", last, got) 1126 } 1127 } 1128 1129 var open, close int32 1130 for i := 0; i < 150; i++ { 1131 open, close = atomic.LoadInt32(&numOpen), atomic.LoadInt32(&numClose) 1132 if open < 1 { 1133 t.Fatalf("open = %d; want at least", open) 1134 } 1135 if close == open-1 { 1136 // Success 1137 return 1138 } 1139 time.Sleep(10 * time.Millisecond) 1140 } 1141 t.Errorf("%d connections opened, %d closed; want %d to close", open, close, open-1) 1142 } 1143 1144 // tests that Transport doesn't retain a pointer to the provided request. 1145 func TestTransportGCRequest(t *testing.T) { 1146 run(t, func(t *testing.T, mode testMode) { 1147 t.Run("Body", func(t *testing.T) { testTransportGCRequest(t, mode, true) }) 1148 t.Run("NoBody", func(t *testing.T) { testTransportGCRequest(t, mode, false) }) 1149 }) 1150 } 1151 func testTransportGCRequest(t *testing.T, mode testMode, body bool) { 1152 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { 1153 io.ReadAll(r.Body) 1154 if body { 1155 io.WriteString(w, "Hello.") 1156 } 1157 })) 1158 1159 didGC := make(chan struct{}) 1160 (func() { 1161 body := strings.NewReader("some body") 1162 req, _ := NewRequest("POST", cst.ts.URL, body) 1163 runtime.SetFinalizer(req, func(*Request) { close(didGC) }) 1164 res, err := cst.c.Do(req) 1165 if err != nil { 1166 t.Fatal(err) 1167 } 1168 if _, err := io.ReadAll(res.Body); err != nil { 1169 t.Fatal(err) 1170 } 1171 if err := res.Body.Close(); err != nil { 1172 t.Fatal(err) 1173 } 1174 })() 1175 timeout := time.NewTimer(5 * time.Second) 1176 defer timeout.Stop() 1177 for { 1178 select { 1179 case <-didGC: 1180 return 1181 case <-time.After(100 * time.Millisecond): 1182 runtime.GC() 1183 case <-timeout.C: 1184 t.Fatal("never saw GC of request") 1185 } 1186 } 1187 } 1188 1189 func TestTransportRejectsInvalidHeaders(t *testing.T) { run(t, testTransportRejectsInvalidHeaders) } 1190 func testTransportRejectsInvalidHeaders(t *testing.T, mode testMode) { 1191 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { 1192 fmt.Fprintf(w, "Handler saw headers: %q", r.Header) 1193 }), optQuietLog) 1194 cst.tr.DisableKeepAlives = true 1195 1196 tests := []struct { 1197 key, val string 1198 ok bool 1199 }{ 1200 {"Foo", "capital-key", true}, // verify h2 allows capital keys 1201 {"Foo", "foo\x00bar", false}, // \x00 byte in value not allowed 1202 {"Foo", "two\nlines", false}, // \n byte in value not allowed 1203 {"bogus\nkey", "v", false}, // \n byte also not allowed in key 1204 {"A space", "v", false}, // spaces in keys not allowed 1205 {"имя", "v", false}, // key must be ascii 1206 {"name", "валю", true}, // value may be non-ascii 1207 {"", "v", false}, // key must be non-empty 1208 {"k", "", true}, // value may be empty 1209 } 1210 for _, tt := range tests { 1211 dialedc := make(chan bool, 1) 1212 cst.tr.Dial = func(netw, addr string) (net.Conn, error) { 1213 dialedc <- true 1214 return net.Dial(netw, addr) 1215 } 1216 req, _ := NewRequest("GET", cst.ts.URL, nil) 1217 req.Header[tt.key] = []string{tt.val} 1218 res, err := cst.c.Do(req) 1219 var body []byte 1220 if err == nil { 1221 body, _ = io.ReadAll(res.Body) 1222 res.Body.Close() 1223 } 1224 var dialed bool 1225 select { 1226 case <-dialedc: 1227 dialed = true 1228 default: 1229 } 1230 1231 if !tt.ok && dialed { 1232 t.Errorf("For key %q, value %q, transport dialed. Expected local failure. Response was: (%v, %v)\nServer replied with: %s", tt.key, tt.val, res, err, body) 1233 } else if (err == nil) != tt.ok { 1234 t.Errorf("For key %q, value %q; got err = %v; want ok=%v", tt.key, tt.val, err, tt.ok) 1235 } 1236 } 1237 } 1238 1239 func TestInterruptWithPanic(t *testing.T) { 1240 run(t, func(t *testing.T, mode testMode) { 1241 t.Run("boom", func(t *testing.T) { testInterruptWithPanic(t, mode, "boom") }) 1242 t.Run("nil", func(t *testing.T) { t.Setenv("GODEBUG", "panicnil=1"); testInterruptWithPanic(t, mode, nil) }) 1243 t.Run("ErrAbortHandler", func(t *testing.T) { testInterruptWithPanic(t, mode, ErrAbortHandler) }) 1244 }, testNotParallel) 1245 } 1246 func testInterruptWithPanic(t *testing.T, mode testMode, panicValue any) { 1247 const msg = "hello" 1248 1249 testDone := make(chan struct{}) 1250 defer close(testDone) 1251 1252 var errorLog lockedBytesBuffer 1253 gotHeaders := make(chan bool, 1) 1254 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { 1255 io.WriteString(w, msg) 1256 w.(Flusher).Flush() 1257 1258 select { 1259 case <-gotHeaders: 1260 case <-testDone: 1261 } 1262 panic(panicValue) 1263 }), func(ts *httptest.Server) { 1264 ts.Config.ErrorLog = log.New(&errorLog, "", 0) 1265 }) 1266 res, err := cst.c.Get(cst.ts.URL) 1267 if err != nil { 1268 t.Fatal(err) 1269 } 1270 gotHeaders <- true 1271 defer res.Body.Close() 1272 slurp, err := io.ReadAll(res.Body) 1273 if string(slurp) != msg { 1274 t.Errorf("client read %q; want %q", slurp, msg) 1275 } 1276 if err == nil { 1277 t.Errorf("client read all successfully; want some error") 1278 } 1279 logOutput := func() string { 1280 errorLog.Lock() 1281 defer errorLog.Unlock() 1282 return errorLog.String() 1283 } 1284 wantStackLogged := panicValue != nil && panicValue != ErrAbortHandler 1285 1286 waitCondition(t, 10*time.Millisecond, func(d time.Duration) bool { 1287 gotLog := logOutput() 1288 if !wantStackLogged { 1289 if gotLog == "" { 1290 return true 1291 } 1292 t.Fatalf("want no log output; got: %s", gotLog) 1293 } 1294 if gotLog == "" { 1295 if d > 0 { 1296 t.Logf("wanted a stack trace logged; got nothing after %v", d) 1297 } 1298 return false 1299 } 1300 if !strings.Contains(gotLog, "created by ") && strings.Count(gotLog, "\n") < 6 { 1301 if d > 0 { 1302 t.Logf("output doesn't look like a panic stack trace after %v. Got: %s", d, gotLog) 1303 } 1304 return false 1305 } 1306 return true 1307 }) 1308 } 1309 1310 type lockedBytesBuffer struct { 1311 sync.Mutex 1312 bytes.Buffer 1313 } 1314 1315 func (b *lockedBytesBuffer) Write(p []byte) (int, error) { 1316 b.Lock() 1317 defer b.Unlock() 1318 return b.Buffer.Write(p) 1319 } 1320 1321 // Issue 15366 1322 func TestH12_AutoGzipWithDumpResponse(t *testing.T) { 1323 h12Compare{ 1324 Handler: func(w ResponseWriter, r *Request) { 1325 h := w.Header() 1326 h.Set("Content-Encoding", "gzip") 1327 h.Set("Content-Length", "23") 1328 io.WriteString(w, "\x1f\x8b\b\x00\x00\x00\x00\x00\x00\x00s\xf3\xf7\a\x00\xab'\xd4\x1a\x03\x00\x00\x00") 1329 }, 1330 EarlyCheckResponse: func(proto string, res *Response) { 1331 if !res.Uncompressed { 1332 t.Errorf("%s: expected Uncompressed to be set", proto) 1333 } 1334 dump, err := httputil.DumpResponse(res, true) 1335 if err != nil { 1336 t.Errorf("%s: DumpResponse: %v", proto, err) 1337 return 1338 } 1339 if strings.Contains(string(dump), "Connection: close") { 1340 t.Errorf("%s: should not see \"Connection: close\" in dump; got:\n%s", proto, dump) 1341 } 1342 if !strings.Contains(string(dump), "FOO") { 1343 t.Errorf("%s: should see \"FOO\" in response; got:\n%s", proto, dump) 1344 } 1345 }, 1346 }.run(t) 1347 } 1348 1349 // Issue 14607 1350 func TestCloseIdleConnections(t *testing.T) { run(t, testCloseIdleConnections) } 1351 func testCloseIdleConnections(t *testing.T, mode testMode) { 1352 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { 1353 w.Header().Set("X-Addr", r.RemoteAddr) 1354 })) 1355 get := func() string { 1356 res, err := cst.c.Get(cst.ts.URL) 1357 if err != nil { 1358 t.Fatal(err) 1359 } 1360 res.Body.Close() 1361 v := res.Header.Get("X-Addr") 1362 if v == "" { 1363 t.Fatal("didn't get X-Addr") 1364 } 1365 return v 1366 } 1367 a1 := get() 1368 cst.tr.CloseIdleConnections() 1369 a2 := get() 1370 if a1 == a2 { 1371 t.Errorf("didn't close connection") 1372 } 1373 } 1374 1375 type noteCloseConn struct { 1376 net.Conn 1377 closeFunc func() 1378 } 1379 1380 func (x noteCloseConn) Close() error { 1381 x.closeFunc() 1382 return x.Conn.Close() 1383 } 1384 1385 type testErrorReader struct{ t *testing.T } 1386 1387 func (r testErrorReader) Read(p []byte) (n int, err error) { 1388 r.t.Error("unexpected Read call") 1389 return 0, io.EOF 1390 } 1391 1392 func TestNoSniffExpectRequestBody(t *testing.T) { run(t, testNoSniffExpectRequestBody) } 1393 func testNoSniffExpectRequestBody(t *testing.T, mode testMode) { 1394 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { 1395 w.WriteHeader(StatusUnauthorized) 1396 })) 1397 1398 // Set ExpectContinueTimeout non-zero so RoundTrip won't try to write it. 1399 cst.tr.ExpectContinueTimeout = 10 * time.Second 1400 1401 req, err := NewRequest("POST", cst.ts.URL, testErrorReader{t}) 1402 if err != nil { 1403 t.Fatal(err) 1404 } 1405 req.ContentLength = 0 // so transport is tempted to sniff it 1406 req.Header.Set("Expect", "100-continue") 1407 res, err := cst.tr.RoundTrip(req) 1408 if err != nil { 1409 t.Fatal(err) 1410 } 1411 defer res.Body.Close() 1412 if res.StatusCode != StatusUnauthorized { 1413 t.Errorf("status code = %v; want %v", res.StatusCode, StatusUnauthorized) 1414 } 1415 } 1416 1417 func TestServerUndeclaredTrailers(t *testing.T) { run(t, testServerUndeclaredTrailers) } 1418 func testServerUndeclaredTrailers(t *testing.T, mode testMode) { 1419 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { 1420 w.Header().Set("Foo", "Bar") 1421 w.Header().Set("Trailer:Foo", "Baz") 1422 w.(Flusher).Flush() 1423 w.Header().Add("Trailer:Foo", "Baz2") 1424 w.Header().Set("Trailer:Bar", "Quux") 1425 })) 1426 res, err := cst.c.Get(cst.ts.URL) 1427 if err != nil { 1428 t.Fatal(err) 1429 } 1430 if _, err := io.Copy(io.Discard, res.Body); err != nil { 1431 t.Fatal(err) 1432 } 1433 res.Body.Close() 1434 delete(res.Header, "Date") 1435 delete(res.Header, "Content-Type") 1436 1437 if want := (Header{"Foo": {"Bar"}}); !reflect.DeepEqual(res.Header, want) { 1438 t.Errorf("Header = %#v; want %#v", res.Header, want) 1439 } 1440 if want := (Header{"Foo": {"Baz", "Baz2"}, "Bar": {"Quux"}}); !reflect.DeepEqual(res.Trailer, want) { 1441 t.Errorf("Trailer = %#v; want %#v", res.Trailer, want) 1442 } 1443 } 1444 1445 func TestBadResponseAfterReadingBody(t *testing.T) { 1446 run(t, testBadResponseAfterReadingBody, []testMode{http1Mode}) 1447 } 1448 func testBadResponseAfterReadingBody(t *testing.T, mode testMode) { 1449 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { 1450 _, err := io.Copy(io.Discard, r.Body) 1451 if err != nil { 1452 t.Fatal(err) 1453 } 1454 c, _, err := w.(Hijacker).Hijack() 1455 if err != nil { 1456 t.Fatal(err) 1457 } 1458 defer c.Close() 1459 fmt.Fprintln(c, "some bogus crap") 1460 })) 1461 1462 closes := 0 1463 res, err := cst.c.Post(cst.ts.URL, "text/plain", countCloseReader{&closes, strings.NewReader("hello")}) 1464 if err == nil { 1465 res.Body.Close() 1466 t.Fatal("expected an error to be returned from Post") 1467 } 1468 if closes != 1 { 1469 t.Errorf("closes = %d; want 1", closes) 1470 } 1471 } 1472 1473 func TestWriteHeader0(t *testing.T) { run(t, testWriteHeader0) } 1474 func testWriteHeader0(t *testing.T, mode testMode) { 1475 gotpanic := make(chan bool, 1) 1476 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { 1477 defer close(gotpanic) 1478 defer func() { 1479 if e := recover(); e != nil { 1480 got := fmt.Sprintf("%T, %v", e, e) 1481 want := "string, invalid WriteHeader code 0" 1482 if got != want { 1483 t.Errorf("unexpected panic value:\n got: %v\nwant: %v\n", got, want) 1484 } 1485 gotpanic <- true 1486 1487 // Set an explicit 503. This also tests that the WriteHeader call panics 1488 // before it recorded that an explicit value was set and that bogus 1489 // value wasn't stuck. 1490 w.WriteHeader(503) 1491 } 1492 }() 1493 w.WriteHeader(0) 1494 })) 1495 res, err := cst.c.Get(cst.ts.URL) 1496 if err != nil { 1497 t.Fatal(err) 1498 } 1499 if res.StatusCode != 503 { 1500 t.Errorf("Response: %v %q; want 503", res.StatusCode, res.Status) 1501 } 1502 if !<-gotpanic { 1503 t.Error("expected panic in handler") 1504 } 1505 } 1506 1507 // Issue 23010: don't be super strict checking WriteHeader's code if 1508 // it's not even valid to call WriteHeader then anyway. 1509 func TestWriteHeaderNoCodeCheck(t *testing.T) { 1510 run(t, func(t *testing.T, mode testMode) { 1511 testWriteHeaderAfterWrite(t, mode, false) 1512 }) 1513 } 1514 func TestWriteHeaderNoCodeCheck_h1hijack(t *testing.T) { 1515 testWriteHeaderAfterWrite(t, http1Mode, true) 1516 } 1517 func testWriteHeaderAfterWrite(t *testing.T, mode testMode, hijack bool) { 1518 var errorLog lockedBytesBuffer 1519 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { 1520 if hijack { 1521 conn, _, _ := w.(Hijacker).Hijack() 1522 defer conn.Close() 1523 conn.Write([]byte("HTTP/1.1 200 OK\r\nContent-Length: 6\r\n\r\nfoo")) 1524 w.WriteHeader(0) // verify this doesn't panic if there's already output; Issue 23010 1525 conn.Write([]byte("bar")) 1526 return 1527 } 1528 io.WriteString(w, "foo") 1529 w.(Flusher).Flush() 1530 w.WriteHeader(0) // verify this doesn't panic if there's already output; Issue 23010 1531 io.WriteString(w, "bar") 1532 }), func(ts *httptest.Server) { 1533 ts.Config.ErrorLog = log.New(&errorLog, "", 0) 1534 }) 1535 res, err := cst.c.Get(cst.ts.URL) 1536 if err != nil { 1537 t.Fatal(err) 1538 } 1539 defer res.Body.Close() 1540 body, err := io.ReadAll(res.Body) 1541 if err != nil { 1542 t.Fatal(err) 1543 } 1544 if got, want := string(body), "foobar"; got != want { 1545 t.Errorf("got = %q; want %q", got, want) 1546 } 1547 1548 // Also check the stderr output: 1549 if mode == http2Mode { 1550 // TODO: also emit this log message for HTTP/2? 1551 // We historically haven't, so don't check. 1552 return 1553 } 1554 gotLog := strings.TrimSpace(errorLog.String()) 1555 wantLog := "http: superfluous response.WriteHeader call from net/http_test.testWriteHeaderAfterWrite.func1 (clientserver_test.go:" 1556 if hijack { 1557 wantLog = "http: response.WriteHeader on hijacked connection from net/http_test.testWriteHeaderAfterWrite.func1 (clientserver_test.go:" 1558 } 1559 if !strings.HasPrefix(gotLog, wantLog) { 1560 t.Errorf("stderr output = %q; want %q", gotLog, wantLog) 1561 } 1562 } 1563 1564 func TestBidiStreamReverseProxy(t *testing.T) { 1565 run(t, testBidiStreamReverseProxy, []testMode{http2Mode}) 1566 } 1567 func testBidiStreamReverseProxy(t *testing.T, mode testMode) { 1568 backend := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { 1569 if _, err := io.Copy(w, r.Body); err != nil { 1570 log.Printf("bidi backend copy: %v", err) 1571 } 1572 })) 1573 1574 backURL, err := url.Parse(backend.ts.URL) 1575 if err != nil { 1576 t.Fatal(err) 1577 } 1578 rp := httputil.NewSingleHostReverseProxy(backURL) 1579 rp.Transport = backend.tr 1580 proxy := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { 1581 rp.ServeHTTP(w, r) 1582 })) 1583 1584 bodyRes := make(chan any, 1) // error or hash.Hash 1585 pr, pw := io.Pipe() 1586 req, _ := NewRequest("PUT", proxy.ts.URL, pr) 1587 const size = 4 << 20 1588 go func() { 1589 h := sha1.New() 1590 _, err := io.CopyN(io.MultiWriter(h, pw), rand.Reader, size) 1591 go pw.Close() 1592 if err != nil { 1593 bodyRes <- err 1594 } else { 1595 bodyRes <- h 1596 } 1597 }() 1598 res, err := backend.c.Do(req) 1599 if err != nil { 1600 t.Fatal(err) 1601 } 1602 defer res.Body.Close() 1603 hgot := sha1.New() 1604 n, err := io.Copy(hgot, res.Body) 1605 if err != nil { 1606 t.Fatal(err) 1607 } 1608 if n != size { 1609 t.Fatalf("got %d bytes; want %d", n, size) 1610 } 1611 select { 1612 case v := <-bodyRes: 1613 switch v := v.(type) { 1614 default: 1615 t.Fatalf("body copy: %v", err) 1616 case hash.Hash: 1617 if !bytes.Equal(v.Sum(nil), hgot.Sum(nil)) { 1618 t.Errorf("written bytes didn't match received bytes") 1619 } 1620 } 1621 case <-time.After(10 * time.Second): 1622 t.Fatal("timeout") 1623 } 1624 1625 } 1626 1627 // Always use HTTP/1.1 for WebSocket upgrades. 1628 func TestH12_WebSocketUpgrade(t *testing.T) { 1629 h12Compare{ 1630 Handler: func(w ResponseWriter, r *Request) { 1631 h := w.Header() 1632 h.Set("Foo", "bar") 1633 }, 1634 ReqFunc: func(c *Client, url string) (*Response, error) { 1635 req, _ := NewRequest("GET", url, nil) 1636 req.Header.Set("Connection", "Upgrade") 1637 req.Header.Set("Upgrade", "WebSocket") 1638 return c.Do(req) 1639 }, 1640 EarlyCheckResponse: func(proto string, res *Response) { 1641 if res.Proto != "HTTP/1.1" { 1642 t.Errorf("%s: expected HTTP/1.1, got %q", proto, res.Proto) 1643 } 1644 res.Proto = "HTTP/IGNORE" // skip later checks that Proto must be 1.1 vs 2.0 1645 }, 1646 }.run(t) 1647 } 1648 1649 func TestIdentityTransferEncoding(t *testing.T) { run(t, testIdentityTransferEncoding) } 1650 func testIdentityTransferEncoding(t *testing.T, mode testMode) { 1651 const body = "body" 1652 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { 1653 gotBody, _ := io.ReadAll(r.Body) 1654 if got, want := string(gotBody), body; got != want { 1655 t.Errorf("got request body = %q; want %q", got, want) 1656 } 1657 w.Header().Set("Transfer-Encoding", "identity") 1658 w.WriteHeader(StatusOK) 1659 w.(Flusher).Flush() 1660 io.WriteString(w, body) 1661 })) 1662 req, _ := NewRequest("GET", cst.ts.URL, strings.NewReader(body)) 1663 res, err := cst.c.Do(req) 1664 if err != nil { 1665 t.Fatal(err) 1666 } 1667 defer res.Body.Close() 1668 gotBody, err := io.ReadAll(res.Body) 1669 if err != nil { 1670 t.Fatal(err) 1671 } 1672 if got, want := string(gotBody), body; got != want { 1673 t.Errorf("got response body = %q; want %q", got, want) 1674 } 1675 } 1676 1677 func TestEarlyHintsRequest(t *testing.T) { run(t, testEarlyHintsRequest) } 1678 func testEarlyHintsRequest(t *testing.T, mode testMode) { 1679 var wg sync.WaitGroup 1680 wg.Add(1) 1681 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { 1682 h := w.Header() 1683 1684 h.Add("Content-Length", "123") // must be ignored 1685 h.Add("Link", "</style.css>; rel=preload; as=style") 1686 h.Add("Link", "</script.js>; rel=preload; as=script") 1687 w.WriteHeader(StatusEarlyHints) 1688 1689 wg.Wait() 1690 1691 h.Add("Link", "</foo.js>; rel=preload; as=script") 1692 w.WriteHeader(StatusEarlyHints) 1693 1694 w.Write([]byte("Hello")) 1695 })) 1696 1697 checkLinkHeaders := func(t *testing.T, expected, got []string) { 1698 t.Helper() 1699 1700 if len(expected) != len(got) { 1701 t.Errorf("got %d expected %d", len(got), len(expected)) 1702 } 1703 1704 for i := range expected { 1705 if expected[i] != got[i] { 1706 t.Errorf("got %q expected %q", got[i], expected[i]) 1707 } 1708 } 1709 } 1710 1711 checkExcludedHeaders := func(t *testing.T, header textproto.MIMEHeader) { 1712 t.Helper() 1713 1714 for _, h := range []string{"Content-Length", "Transfer-Encoding"} { 1715 if v, ok := header[h]; ok { 1716 t.Errorf("%s is %q; must not be sent", h, v) 1717 } 1718 } 1719 } 1720 1721 var respCounter uint8 1722 trace := &httptrace.ClientTrace{ 1723 Got1xxResponse: func(code int, header textproto.MIMEHeader) error { 1724 switch respCounter { 1725 case 0: 1726 checkLinkHeaders(t, []string{"</style.css>; rel=preload; as=style", "</script.js>; rel=preload; as=script"}, header["Link"]) 1727 checkExcludedHeaders(t, header) 1728 1729 wg.Done() 1730 case 1: 1731 checkLinkHeaders(t, []string{"</style.css>; rel=preload; as=style", "</script.js>; rel=preload; as=script", "</foo.js>; rel=preload; as=script"}, header["Link"]) 1732 checkExcludedHeaders(t, header) 1733 1734 default: 1735 t.Error("Unexpected 1xx response") 1736 } 1737 1738 respCounter++ 1739 1740 return nil 1741 }, 1742 } 1743 req, _ := NewRequestWithContext(httptrace.WithClientTrace(context.Background(), trace), "GET", cst.ts.URL, nil) 1744 1745 res, err := cst.c.Do(req) 1746 if err != nil { 1747 t.Fatal(err) 1748 } 1749 defer res.Body.Close() 1750 1751 checkLinkHeaders(t, []string{"</style.css>; rel=preload; as=style", "</script.js>; rel=preload; as=script", "</foo.js>; rel=preload; as=script"}, res.Header["Link"]) 1752 if cl := res.Header.Get("Content-Length"); cl != "123" { 1753 t.Errorf("Content-Length is %q; want 123", cl) 1754 } 1755 1756 body, _ := io.ReadAll(res.Body) 1757 if string(body) != "Hello" { 1758 t.Errorf("Read body %q; want Hello", body) 1759 } 1760 }