// Copyright 2015 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

// Tests that use both the client & server, in both HTTP/1 and HTTP/2 mode.

package http_test

import (
	"bytes"
	"compress/gzip"
	"context"
	"crypto/rand"
	"crypto/sha1"
	"crypto/tls"
	"fmt"
	"hash"
	"io"
	"log"
	"net"
	. "net/http"
	"net/http/httptest"
	"net/http/httptrace"
	"net/http/httputil"
	"net/textproto"
	"net/url"
	"os"
	"reflect"
	"runtime"
	"sort"
	"strings"
	"sync"
	"sync/atomic"
	"testing"
	"time"
)

// testMode names a client/server configuration a test can run under.
// The string value is used as the subtest name by run.
type testMode string

const (
	http1Mode  = testMode("h1")     // HTTP/1.1
	https1Mode = testMode("https1") // HTTPS/1.1
	http2Mode  = testMode("h2")     // HTTP/2
)

// testNotParallelOpt is the type of the testNotParallel option accepted by run.
type testNotParallelOpt struct{}

var (
	// testNotParallel, passed as an option to run, suppresses the
	// setParallel calls run would otherwise make.
	testNotParallel = testNotParallelOpt{}
)

// TBRun is satisfied by test handles (for example *testing.T) that can
// start a named subtest whose callback receives the same handle type T.
type TBRun[T any] interface {
	testing.TB
	Run(string, func(T)) bool
}

// run runs a client/server test in a variety of test configurations.
//
// Tests execute in HTTP/1.1 and HTTP/2 modes by default.
// To run in a different set of configurations, pass a []testMode option.
//
// Tests call t.Parallel() by default.
// To disable parallel execution, pass the testNotParallel option.
64 func run[T TBRun[T]](t T, f func(t T, mode testMode), opts ...any) { 65 t.Helper() 66 modes := []testMode{http1Mode, http2Mode} 67 parallel := true 68 for _, opt := range opts { 69 switch opt := opt.(type) { 70 case []testMode: 71 modes = opt 72 case testNotParallelOpt: 73 parallel = false 74 default: 75 t.Fatalf("unknown option type %T", opt) 76 } 77 } 78 if t, ok := any(t).(*testing.T); ok && parallel { 79 setParallel(t) 80 } 81 for _, mode := range modes { 82 t.Run(string(mode), func(t T) { 83 t.Helper() 84 if t, ok := any(t).(*testing.T); ok && parallel { 85 setParallel(t) 86 } 87 t.Cleanup(func() { 88 afterTest(t) 89 }) 90 f(t, mode) 91 }) 92 } 93 } 94 95 type clientServerTest struct { 96 t testing.TB 97 h2 bool 98 h Handler 99 ts *httptest.Server 100 tr *Transport 101 c *Client 102 } 103 104 func (t *clientServerTest) close() { 105 t.tr.CloseIdleConnections() 106 t.ts.Close() 107 } 108 109 func (t *clientServerTest) getURL(u string) string { 110 res, err := t.c.Get(u) 111 if err != nil { 112 t.t.Fatal(err) 113 } 114 defer res.Body.Close() 115 slurp, err := io.ReadAll(res.Body) 116 if err != nil { 117 t.t.Fatal(err) 118 } 119 return string(slurp) 120 } 121 122 func (t *clientServerTest) scheme() string { 123 if t.h2 { 124 return "https" 125 } 126 return "http" 127 } 128 129 var optQuietLog = func(ts *httptest.Server) { 130 ts.Config.ErrorLog = quietLog 131 } 132 133 func optWithServerLog(lg *log.Logger) func(*httptest.Server) { 134 return func(ts *httptest.Server) { 135 ts.Config.ErrorLog = lg 136 } 137 } 138 139 // newClientServerTest creates and starts an httptest.Server. 140 // 141 // The mode parameter selects the implementation to test: 142 // HTTP/1, HTTP/2, etc. Tests using newClientServerTest should use 143 // the 'run' function, which will start a subtests for each tested mode. 144 // 145 // The vararg opts parameter can include functions to configure the 146 // test server or transport. 
147 // 148 // func(*httptest.Server) // run before starting the server 149 // func(*http.Transport) 150 func newClientServerTest(t testing.TB, mode testMode, h Handler, opts ...any) *clientServerTest { 151 if mode == http2Mode { 152 CondSkipHTTP2(t) 153 } 154 cst := &clientServerTest{ 155 t: t, 156 h2: mode == http2Mode, 157 h: h, 158 } 159 cst.ts = httptest.NewUnstartedServer(h) 160 161 var transportFuncs []func(*Transport) 162 for _, opt := range opts { 163 switch opt := opt.(type) { 164 case func(*Transport): 165 transportFuncs = append(transportFuncs, opt) 166 case func(*httptest.Server): 167 opt(cst.ts) 168 default: 169 t.Fatalf("unhandled option type %T", opt) 170 } 171 } 172 173 switch mode { 174 case http1Mode: 175 cst.ts.Start() 176 case https1Mode: 177 cst.ts.StartTLS() 178 case http2Mode: 179 ExportHttp2ConfigureServer(cst.ts.Config, nil) 180 cst.ts.TLS = cst.ts.Config.TLSConfig 181 cst.ts.StartTLS() 182 default: 183 t.Fatalf("unknown test mode %v", mode) 184 } 185 cst.c = cst.ts.Client() 186 cst.tr = cst.c.Transport.(*Transport) 187 if mode == http2Mode { 188 if err := ExportHttp2ConfigureTransport(cst.tr); err != nil { 189 t.Fatal(err) 190 } 191 } 192 for _, f := range transportFuncs { 193 f(cst.tr) 194 } 195 t.Cleanup(func() { 196 cst.close() 197 }) 198 return cst 199 } 200 201 // Testing the newClientServerTest helper itself. 
202 func TestNewClientServerTest(t *testing.T) { 203 run(t, testNewClientServerTest, []testMode{http1Mode, https1Mode, http2Mode}) 204 } 205 func testNewClientServerTest(t *testing.T, mode testMode) { 206 var got struct { 207 sync.Mutex 208 proto string 209 hasTLS bool 210 } 211 h := HandlerFunc(func(w ResponseWriter, r *Request) { 212 got.Lock() 213 defer got.Unlock() 214 got.proto = r.Proto 215 got.hasTLS = r.TLS != nil 216 }) 217 cst := newClientServerTest(t, mode, h) 218 if _, err := cst.c.Head(cst.ts.URL); err != nil { 219 t.Fatal(err) 220 } 221 var wantProto string 222 var wantTLS bool 223 switch mode { 224 case http1Mode: 225 wantProto = "HTTP/1.1" 226 wantTLS = false 227 case https1Mode: 228 wantProto = "HTTP/1.1" 229 wantTLS = true 230 case http2Mode: 231 wantProto = "HTTP/2.0" 232 wantTLS = true 233 } 234 if got.proto != wantProto { 235 t.Errorf("req.Proto = %q, want %q", got.proto, wantProto) 236 } 237 if got.hasTLS != wantTLS { 238 t.Errorf("req.TLS set: %v, want %v", got.hasTLS, wantTLS) 239 } 240 } 241 242 func TestChunkedResponseHeaders(t *testing.T) { run(t, testChunkedResponseHeaders) } 243 func testChunkedResponseHeaders(t *testing.T, mode testMode) { 244 log.SetOutput(io.Discard) // is noisy otherwise 245 defer log.SetOutput(os.Stderr) 246 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { 247 w.Header().Set("Content-Length", "intentional gibberish") // we check that this is deleted 248 w.(Flusher).Flush() 249 fmt.Fprintf(w, "I am a chunked response.") 250 })) 251 252 res, err := cst.c.Get(cst.ts.URL) 253 if err != nil { 254 t.Fatalf("Get error: %v", err) 255 } 256 defer res.Body.Close() 257 if g, e := res.ContentLength, int64(-1); g != e { 258 t.Errorf("expected ContentLength of %d; got %d", e, g) 259 } 260 wantTE := []string{"chunked"} 261 if mode == http2Mode { 262 wantTE = nil 263 } 264 if !reflect.DeepEqual(res.TransferEncoding, wantTE) { 265 t.Errorf("TransferEncoding = %v; want %v", res.TransferEncoding, 
wantTE) 266 } 267 if got, haveCL := res.Header["Content-Length"]; haveCL { 268 t.Errorf("Unexpected Content-Length: %q", got) 269 } 270 } 271 272 type reqFunc func(c *Client, url string) (*Response, error) 273 274 // h12Compare is a test that compares HTTP/1 and HTTP/2 behavior 275 // against each other. 276 type h12Compare struct { 277 Handler func(ResponseWriter, *Request) // required 278 ReqFunc reqFunc // optional 279 CheckResponse func(proto string, res *Response) // optional 280 EarlyCheckResponse func(proto string, res *Response) // optional; pre-normalize 281 Opts []any 282 } 283 284 func (tt h12Compare) reqFunc() reqFunc { 285 if tt.ReqFunc == nil { 286 return (*Client).Get 287 } 288 return tt.ReqFunc 289 } 290 291 func (tt h12Compare) run(t *testing.T) { 292 setParallel(t) 293 cst1 := newClientServerTest(t, http1Mode, HandlerFunc(tt.Handler), tt.Opts...) 294 defer cst1.close() 295 cst2 := newClientServerTest(t, http2Mode, HandlerFunc(tt.Handler), tt.Opts...) 296 defer cst2.close() 297 298 res1, err := tt.reqFunc()(cst1.c, cst1.ts.URL) 299 if err != nil { 300 t.Errorf("HTTP/1 request: %v", err) 301 return 302 } 303 res2, err := tt.reqFunc()(cst2.c, cst2.ts.URL) 304 if err != nil { 305 t.Errorf("HTTP/2 request: %v", err) 306 return 307 } 308 309 if fn := tt.EarlyCheckResponse; fn != nil { 310 fn("HTTP/1.1", res1) 311 fn("HTTP/2.0", res2) 312 } 313 314 tt.normalizeRes(t, res1, "HTTP/1.1") 315 tt.normalizeRes(t, res2, "HTTP/2.0") 316 res1body, res2body := res1.Body, res2.Body 317 318 eres1 := mostlyCopy(res1) 319 eres2 := mostlyCopy(res2) 320 if !reflect.DeepEqual(eres1, eres2) { 321 t.Errorf("Response headers to handler differed:\nhttp/1 (%v):\n\t%#v\nhttp/2 (%v):\n\t%#v", 322 cst1.ts.URL, eres1, cst2.ts.URL, eres2) 323 } 324 if !reflect.DeepEqual(res1body, res2body) { 325 t.Errorf("Response bodies to handler differed.\nhttp1: %v\nhttp2: %v\n", res1body, res2body) 326 } 327 if fn := tt.CheckResponse; fn != nil { 328 res1.Body, res2.Body = res1body, res2body 
329 fn("HTTP/1.1", res1) 330 fn("HTTP/2.0", res2) 331 } 332 } 333 334 func mostlyCopy(r *Response) *Response { 335 c := *r 336 c.Body = nil 337 c.TransferEncoding = nil 338 c.TLS = nil 339 c.Request = nil 340 return &c 341 } 342 343 type slurpResult struct { 344 io.ReadCloser 345 body []byte 346 err error 347 } 348 349 func (sr slurpResult) String() string { return fmt.Sprintf("body %q; err %v", sr.body, sr.err) } 350 351 func (tt h12Compare) normalizeRes(t *testing.T, res *Response, wantProto string) { 352 if res.Proto == wantProto || res.Proto == "HTTP/IGNORE" { 353 res.Proto, res.ProtoMajor, res.ProtoMinor = "", 0, 0 354 } else { 355 t.Errorf("got %q response; want %q", res.Proto, wantProto) 356 } 357 slurp, err := io.ReadAll(res.Body) 358 359 res.Body.Close() 360 res.Body = slurpResult{ 361 ReadCloser: io.NopCloser(bytes.NewReader(slurp)), 362 body: slurp, 363 err: err, 364 } 365 for i, v := range res.Header["Date"] { 366 res.Header["Date"][i] = strings.Repeat("x", len(v)) 367 } 368 if res.Request == nil { 369 t.Errorf("for %s, no request", wantProto) 370 } 371 if (res.TLS != nil) != (wantProto == "HTTP/2.0") { 372 t.Errorf("TLS set = %v; want %v", res.TLS != nil, res.TLS == nil) 373 } 374 } 375 376 // Issue 13532 377 func TestH12_HeadContentLengthNoBody(t *testing.T) { 378 h12Compare{ 379 ReqFunc: (*Client).Head, 380 Handler: func(w ResponseWriter, r *Request) { 381 }, 382 }.run(t) 383 } 384 385 func TestH12_HeadContentLengthSmallBody(t *testing.T) { 386 h12Compare{ 387 ReqFunc: (*Client).Head, 388 Handler: func(w ResponseWriter, r *Request) { 389 io.WriteString(w, "small") 390 }, 391 }.run(t) 392 } 393 394 func TestH12_HeadContentLengthLargeBody(t *testing.T) { 395 h12Compare{ 396 ReqFunc: (*Client).Head, 397 Handler: func(w ResponseWriter, r *Request) { 398 chunk := strings.Repeat("x", 512<<10) 399 for i := 0; i < 10; i++ { 400 io.WriteString(w, chunk) 401 } 402 }, 403 }.run(t) 404 } 405 406 func TestH12_200NoBody(t *testing.T) { 407 h12Compare{Handler: 
func(w ResponseWriter, r *Request) {}}.run(t) 408 } 409 410 func TestH2_204NoBody(t *testing.T) { testH12_noBody(t, 204) } 411 func TestH2_304NoBody(t *testing.T) { testH12_noBody(t, 304) } 412 func TestH2_404NoBody(t *testing.T) { testH12_noBody(t, 404) } 413 414 func testH12_noBody(t *testing.T, status int) { 415 h12Compare{Handler: func(w ResponseWriter, r *Request) { 416 w.WriteHeader(status) 417 }}.run(t) 418 } 419 420 func TestH12_SmallBody(t *testing.T) { 421 h12Compare{Handler: func(w ResponseWriter, r *Request) { 422 io.WriteString(w, "small body") 423 }}.run(t) 424 } 425 426 func TestH12_ExplicitContentLength(t *testing.T) { 427 h12Compare{Handler: func(w ResponseWriter, r *Request) { 428 w.Header().Set("Content-Length", "3") 429 io.WriteString(w, "foo") 430 }}.run(t) 431 } 432 433 func TestH12_FlushBeforeBody(t *testing.T) { 434 h12Compare{Handler: func(w ResponseWriter, r *Request) { 435 w.(Flusher).Flush() 436 io.WriteString(w, "foo") 437 }}.run(t) 438 } 439 440 func TestH12_FlushMidBody(t *testing.T) { 441 h12Compare{Handler: func(w ResponseWriter, r *Request) { 442 io.WriteString(w, "foo") 443 w.(Flusher).Flush() 444 io.WriteString(w, "bar") 445 }}.run(t) 446 } 447 448 func TestH12_Head_ExplicitLen(t *testing.T) { 449 h12Compare{ 450 ReqFunc: (*Client).Head, 451 Handler: func(w ResponseWriter, r *Request) { 452 if r.Method != "HEAD" { 453 t.Errorf("unexpected method %q", r.Method) 454 } 455 w.Header().Set("Content-Length", "1235") 456 }, 457 }.run(t) 458 } 459 460 func TestH12_Head_ImplicitLen(t *testing.T) { 461 h12Compare{ 462 ReqFunc: (*Client).Head, 463 Handler: func(w ResponseWriter, r *Request) { 464 if r.Method != "HEAD" { 465 t.Errorf("unexpected method %q", r.Method) 466 } 467 io.WriteString(w, "foo") 468 }, 469 }.run(t) 470 } 471 472 func TestH12_HandlerWritesTooLittle(t *testing.T) { 473 h12Compare{ 474 Handler: func(w ResponseWriter, r *Request) { 475 w.Header().Set("Content-Length", "3") 476 io.WriteString(w, "12") // one byte short 477 
}, 478 CheckResponse: func(proto string, res *Response) { 479 sr, ok := res.Body.(slurpResult) 480 if !ok { 481 t.Errorf("%s body is %T; want slurpResult", proto, res.Body) 482 return 483 } 484 if sr.err != io.ErrUnexpectedEOF { 485 t.Errorf("%s read error = %v; want io.ErrUnexpectedEOF", proto, sr.err) 486 } 487 if string(sr.body) != "12" { 488 t.Errorf("%s body = %q; want %q", proto, sr.body, "12") 489 } 490 }, 491 }.run(t) 492 } 493 494 // Tests that the HTTP/1 and HTTP/2 servers prevent handlers from 495 // writing more than they declared. This test does not test whether 496 // the transport deals with too much data, though, since the server 497 // doesn't make it possible to send bogus data. For those tests, see 498 // transport_test.go (for HTTP/1) or x/net/http2/transport_test.go 499 // (for HTTP/2). 500 func TestH12_HandlerWritesTooMuch(t *testing.T) { 501 h12Compare{ 502 Handler: func(w ResponseWriter, r *Request) { 503 w.Header().Set("Content-Length", "3") 504 w.(Flusher).Flush() 505 io.WriteString(w, "123") 506 w.(Flusher).Flush() 507 n, err := io.WriteString(w, "x") // too many 508 if n > 0 || err == nil { 509 t.Errorf("for proto %q, final write = %v, %v; want 0, some error", r.Proto, n, err) 510 } 511 }, 512 }.run(t) 513 } 514 515 // Verify that both our HTTP/1 and HTTP/2 request and auto-decompress gzip. 516 // Some hosts send gzip even if you don't ask for it; see golang.org/issue/13298 517 func TestH12_AutoGzip(t *testing.T) { 518 h12Compare{ 519 Handler: func(w ResponseWriter, r *Request) { 520 if ae := r.Header.Get("Accept-Encoding"); ae != "gzip" { 521 t.Errorf("%s Accept-Encoding = %q; want gzip", r.Proto, ae) 522 } 523 w.Header().Set("Content-Encoding", "gzip") 524 gz := gzip.NewWriter(w) 525 io.WriteString(gz, "I am some gzipped content. 
Go go go go go go go go go go go go should compress well.") 526 gz.Close() 527 }, 528 }.run(t) 529 } 530 531 func TestH12_AutoGzip_Disabled(t *testing.T) { 532 h12Compare{ 533 Opts: []any{ 534 func(tr *Transport) { tr.DisableCompression = true }, 535 }, 536 Handler: func(w ResponseWriter, r *Request) { 537 fmt.Fprintf(w, "%q", r.Header["Accept-Encoding"]) 538 if ae := r.Header.Get("Accept-Encoding"); ae != "" { 539 t.Errorf("%s Accept-Encoding = %q; want empty", r.Proto, ae) 540 } 541 }, 542 }.run(t) 543 } 544 545 // Test304Responses verifies that 304s don't declare that they're 546 // chunking in their response headers and aren't allowed to produce 547 // output. 548 func Test304Responses(t *testing.T) { run(t, test304Responses) } 549 func test304Responses(t *testing.T, mode testMode) { 550 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { 551 w.WriteHeader(StatusNotModified) 552 _, err := w.Write([]byte("illegal body")) 553 if err != ErrBodyNotAllowed { 554 t.Errorf("on Write, expected ErrBodyNotAllowed, got %v", err) 555 } 556 })) 557 defer cst.close() 558 res, err := cst.c.Get(cst.ts.URL) 559 if err != nil { 560 t.Fatal(err) 561 } 562 if len(res.TransferEncoding) > 0 { 563 t.Errorf("expected no TransferEncoding; got %v", res.TransferEncoding) 564 } 565 body, err := io.ReadAll(res.Body) 566 if err != nil { 567 t.Error(err) 568 } 569 if len(body) > 0 { 570 t.Errorf("got unexpected body %q", string(body)) 571 } 572 } 573 574 func TestH12_ServerEmptyContentLength(t *testing.T) { 575 h12Compare{ 576 Handler: func(w ResponseWriter, r *Request) { 577 w.Header()["Content-Type"] = []string{""} 578 io.WriteString(w, "<html><body>hi</body></html>") 579 }, 580 }.run(t) 581 } 582 583 func TestH12_RequestContentLength_Known_NonZero(t *testing.T) { 584 h12requestContentLength(t, func() io.Reader { return strings.NewReader("FOUR") }, 4) 585 } 586 587 func TestH12_RequestContentLength_Known_Zero(t *testing.T) { 588 h12requestContentLength(t, 
func() io.Reader { return nil }, 0) 589 } 590 591 func TestH12_RequestContentLength_Unknown(t *testing.T) { 592 h12requestContentLength(t, func() io.Reader { return struct{ io.Reader }{strings.NewReader("Stuff")} }, -1) 593 } 594 595 func h12requestContentLength(t *testing.T, bodyfn func() io.Reader, wantLen int64) { 596 h12Compare{ 597 Handler: func(w ResponseWriter, r *Request) { 598 w.Header().Set("Got-Length", fmt.Sprint(r.ContentLength)) 599 fmt.Fprintf(w, "Req.ContentLength=%v", r.ContentLength) 600 }, 601 ReqFunc: func(c *Client, url string) (*Response, error) { 602 return c.Post(url, "text/plain", bodyfn()) 603 }, 604 CheckResponse: func(proto string, res *Response) { 605 if got, want := res.Header.Get("Got-Length"), fmt.Sprint(wantLen); got != want { 606 t.Errorf("Proto %q got length %q; want %q", proto, got, want) 607 } 608 }, 609 }.run(t) 610 } 611 612 // Tests that closing the Request.Cancel channel also while still 613 // reading the response body. Issue 13159. 614 func TestCancelRequestMidBody(t *testing.T) { run(t, testCancelRequestMidBody) } 615 func testCancelRequestMidBody(t *testing.T, mode testMode) { 616 unblock := make(chan bool) 617 didFlush := make(chan bool, 1) 618 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { 619 io.WriteString(w, "Hello") 620 w.(Flusher).Flush() 621 didFlush <- true 622 <-unblock 623 io.WriteString(w, ", world.") 624 })) 625 defer close(unblock) 626 627 req, _ := NewRequest("GET", cst.ts.URL, nil) 628 cancel := make(chan struct{}) 629 req.Cancel = cancel 630 631 res, err := cst.c.Do(req) 632 if err != nil { 633 t.Fatal(err) 634 } 635 defer res.Body.Close() 636 <-didFlush 637 638 // Read a bit before we cancel. (Issue 13626) 639 // We should have "Hello" at least sitting there. 
640 firstRead := make([]byte, 10) 641 n, err := res.Body.Read(firstRead) 642 if err != nil { 643 t.Fatal(err) 644 } 645 firstRead = firstRead[:n] 646 647 close(cancel) 648 649 rest, err := io.ReadAll(res.Body) 650 all := string(firstRead) + string(rest) 651 if all != "Hello" { 652 t.Errorf("Read %q (%q + %q); want Hello", all, firstRead, rest) 653 } 654 if err != ExportErrRequestCanceled { 655 t.Errorf("ReadAll error = %v; want %v", err, ExportErrRequestCanceled) 656 } 657 } 658 659 // Tests that clients can send trailers to a server and that the server can read them. 660 func TestTrailersClientToServer(t *testing.T) { run(t, testTrailersClientToServer) } 661 func testTrailersClientToServer(t *testing.T, mode testMode) { 662 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { 663 var decl []string 664 for k := range r.Trailer { 665 decl = append(decl, k) 666 } 667 sort.Strings(decl) 668 669 slurp, err := io.ReadAll(r.Body) 670 if err != nil { 671 t.Errorf("Server reading request body: %v", err) 672 } 673 if string(slurp) != "foo" { 674 t.Errorf("Server read request body %q; want foo", slurp) 675 } 676 if r.Trailer == nil { 677 io.WriteString(w, "nil Trailer") 678 } else { 679 fmt.Fprintf(w, "decl: %v, vals: %s, %s", 680 decl, 681 r.Trailer.Get("Client-Trailer-A"), 682 r.Trailer.Get("Client-Trailer-B")) 683 } 684 })) 685 686 var req *Request 687 req, _ = NewRequest("POST", cst.ts.URL, io.MultiReader( 688 eofReaderFunc(func() { 689 req.Trailer["Client-Trailer-A"] = []string{"valuea"} 690 }), 691 strings.NewReader("foo"), 692 eofReaderFunc(func() { 693 req.Trailer["Client-Trailer-B"] = []string{"valueb"} 694 }), 695 )) 696 req.Trailer = Header{ 697 "Client-Trailer-A": nil, // to be set later 698 "Client-Trailer-B": nil, // to be set later 699 } 700 req.ContentLength = -1 701 res, err := cst.c.Do(req) 702 if err != nil { 703 t.Fatal(err) 704 } 705 if err := wantBody(res, err, "decl: [Client-Trailer-A Client-Trailer-B], vals: valuea, 
valueb"); err != nil { 706 t.Error(err) 707 } 708 } 709 710 // Tests that servers send trailers to a client and that the client can read them. 711 func TestTrailersServerToClient(t *testing.T) { 712 run(t, func(t *testing.T, mode testMode) { 713 testTrailersServerToClient(t, mode, false) 714 }) 715 } 716 func TestTrailersServerToClientFlush(t *testing.T) { 717 run(t, func(t *testing.T, mode testMode) { 718 testTrailersServerToClient(t, mode, true) 719 }) 720 } 721 722 func testTrailersServerToClient(t *testing.T, mode testMode, flush bool) { 723 const body = "Some body" 724 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { 725 w.Header().Set("Trailer", "Server-Trailer-A, Server-Trailer-B") 726 w.Header().Add("Trailer", "Server-Trailer-C") 727 728 io.WriteString(w, body) 729 if flush { 730 w.(Flusher).Flush() 731 } 732 733 // How handlers set Trailers: declare it ahead of time 734 // with the Trailer header, and then mutate the 735 // Header() of those values later, after the response 736 // has been written (we wrote to w above). 737 w.Header().Set("Server-Trailer-A", "valuea") 738 w.Header().Set("Server-Trailer-C", "valuec") // skipping B 739 w.Header().Set("Server-Trailer-NotDeclared", "should be omitted") 740 })) 741 742 res, err := cst.c.Get(cst.ts.URL) 743 if err != nil { 744 t.Fatal(err) 745 } 746 747 wantHeader := Header{ 748 "Content-Type": {"text/plain; charset=utf-8"}, 749 } 750 wantLen := -1 751 if mode == http2Mode && !flush { 752 // In HTTP/1.1, any use of trailers forces HTTP/1.1 753 // chunking and a flush at the first write. That's 754 // unnecessary with HTTP/2's framing, so the server 755 // is able to calculate the length while still sending 756 // trailers afterwards. 
757 wantLen = len(body) 758 wantHeader["Content-Length"] = []string{fmt.Sprint(wantLen)} 759 } 760 if res.ContentLength != int64(wantLen) { 761 t.Errorf("ContentLength = %v; want %v", res.ContentLength, wantLen) 762 } 763 764 delete(res.Header, "Date") // irrelevant for test 765 if !reflect.DeepEqual(res.Header, wantHeader) { 766 t.Errorf("Header = %v; want %v", res.Header, wantHeader) 767 } 768 769 if got, want := res.Trailer, (Header{ 770 "Server-Trailer-A": nil, 771 "Server-Trailer-B": nil, 772 "Server-Trailer-C": nil, 773 }); !reflect.DeepEqual(got, want) { 774 t.Errorf("Trailer before body read = %v; want %v", got, want) 775 } 776 777 if err := wantBody(res, nil, body); err != nil { 778 t.Fatal(err) 779 } 780 781 if got, want := res.Trailer, (Header{ 782 "Server-Trailer-A": {"valuea"}, 783 "Server-Trailer-B": nil, 784 "Server-Trailer-C": {"valuec"}, 785 }); !reflect.DeepEqual(got, want) { 786 t.Errorf("Trailer after body read = %v; want %v", got, want) 787 } 788 } 789 790 // Don't allow a Body.Read after Body.Close. Issue 13648. 
791 func TestResponseBodyReadAfterClose(t *testing.T) { run(t, testResponseBodyReadAfterClose) } 792 func testResponseBodyReadAfterClose(t *testing.T, mode testMode) { 793 const body = "Some body" 794 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { 795 io.WriteString(w, body) 796 })) 797 res, err := cst.c.Get(cst.ts.URL) 798 if err != nil { 799 t.Fatal(err) 800 } 801 res.Body.Close() 802 data, err := io.ReadAll(res.Body) 803 if len(data) != 0 || err == nil { 804 t.Fatalf("ReadAll returned %q, %v; want error", data, err) 805 } 806 } 807 808 func TestConcurrentReadWriteReqBody(t *testing.T) { run(t, testConcurrentReadWriteReqBody) } 809 func testConcurrentReadWriteReqBody(t *testing.T, mode testMode) { 810 const reqBody = "some request body" 811 const resBody = "some response body" 812 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { 813 var wg sync.WaitGroup 814 wg.Add(2) 815 didRead := make(chan bool, 1) 816 // Read in one goroutine. 817 go func() { 818 defer wg.Done() 819 data, err := io.ReadAll(r.Body) 820 if string(data) != reqBody { 821 t.Errorf("Handler read %q; want %q", data, reqBody) 822 } 823 if err != nil { 824 t.Errorf("Handler Read: %v", err) 825 } 826 didRead <- true 827 }() 828 // Write in another goroutine. 829 go func() { 830 defer wg.Done() 831 if mode != http2Mode { 832 // our HTTP/1 implementation intentionally 833 // doesn't permit writes during read (mostly 834 // due to it being undefined); if that is ever 835 // relaxed, change this. 
836 <-didRead 837 } 838 io.WriteString(w, resBody) 839 }() 840 wg.Wait() 841 })) 842 req, _ := NewRequest("POST", cst.ts.URL, strings.NewReader(reqBody)) 843 req.Header.Add("Expect", "100-continue") // just to complicate things 844 res, err := cst.c.Do(req) 845 if err != nil { 846 t.Fatal(err) 847 } 848 data, err := io.ReadAll(res.Body) 849 defer res.Body.Close() 850 if err != nil { 851 t.Fatal(err) 852 } 853 if string(data) != resBody { 854 t.Errorf("read %q; want %q", data, resBody) 855 } 856 } 857 858 func TestConnectRequest(t *testing.T) { run(t, testConnectRequest) } 859 func testConnectRequest(t *testing.T, mode testMode) { 860 gotc := make(chan *Request, 1) 861 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { 862 gotc <- r 863 })) 864 865 u, err := url.Parse(cst.ts.URL) 866 if err != nil { 867 t.Fatal(err) 868 } 869 870 tests := []struct { 871 req *Request 872 want string 873 }{ 874 { 875 req: &Request{ 876 Method: "CONNECT", 877 Header: Header{}, 878 URL: u, 879 }, 880 want: u.Host, 881 }, 882 { 883 req: &Request{ 884 Method: "CONNECT", 885 Header: Header{}, 886 URL: u, 887 Host: "example.com:123", 888 }, 889 want: "example.com:123", 890 }, 891 } 892 893 for i, tt := range tests { 894 res, err := cst.c.Do(tt.req) 895 if err != nil { 896 t.Errorf("%d. 
RoundTrip = %v", i, err) 897 continue 898 } 899 res.Body.Close() 900 req := <-gotc 901 if req.Method != "CONNECT" { 902 t.Errorf("method = %q; want CONNECT", req.Method) 903 } 904 if req.Host != tt.want { 905 t.Errorf("Host = %q; want %q", req.Host, tt.want) 906 } 907 if req.URL.Host != tt.want { 908 t.Errorf("URL.Host = %q; want %q", req.URL.Host, tt.want) 909 } 910 } 911 } 912 913 func TestTransportUserAgent(t *testing.T) { run(t, testTransportUserAgent) } 914 func testTransportUserAgent(t *testing.T, mode testMode) { 915 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { 916 fmt.Fprintf(w, "%q", r.Header["User-Agent"]) 917 })) 918 919 either := func(a, b string) string { 920 if mode == http2Mode { 921 return b 922 } 923 return a 924 } 925 926 tests := []struct { 927 setup func(*Request) 928 want string 929 }{ 930 { 931 func(r *Request) {}, 932 either(`["Go-http-client/1.1"]`, `["Go-http-client/2.0"]`), 933 }, 934 { 935 func(r *Request) { r.Header.Set("User-Agent", "foo/1.2.3") }, 936 `["foo/1.2.3"]`, 937 }, 938 { 939 func(r *Request) { r.Header["User-Agent"] = []string{"single", "or", "multiple"} }, 940 `["single"]`, 941 }, 942 { 943 func(r *Request) { r.Header.Set("User-Agent", "") }, 944 `[]`, 945 }, 946 { 947 func(r *Request) { r.Header["User-Agent"] = nil }, 948 `[]`, 949 }, 950 } 951 for i, tt := range tests { 952 req, _ := NewRequest("GET", cst.ts.URL, nil) 953 tt.setup(req) 954 res, err := cst.c.Do(req) 955 if err != nil { 956 t.Errorf("%d. RoundTrip = %v", i, err) 957 continue 958 } 959 slurp, err := io.ReadAll(res.Body) 960 res.Body.Close() 961 if err != nil { 962 t.Errorf("%d. read body = %v", i, err) 963 continue 964 } 965 if string(slurp) != tt.want { 966 t.Errorf("%d. 
body mismatch.\n got: %s\nwant: %s\n", i, slurp, tt.want) 967 } 968 } 969 } 970 971 func TestStarRequestMethod(t *testing.T) { 972 for _, method := range []string{"FOO", "OPTIONS"} { 973 t.Run(method, func(t *testing.T) { 974 run(t, func(t *testing.T, mode testMode) { 975 testStarRequest(t, method, mode) 976 }) 977 }) 978 } 979 } 980 func testStarRequest(t *testing.T, method string, mode testMode) { 981 gotc := make(chan *Request, 1) 982 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { 983 w.Header().Set("foo", "bar") 984 gotc <- r 985 w.(Flusher).Flush() 986 })) 987 988 u, err := url.Parse(cst.ts.URL) 989 if err != nil { 990 t.Fatal(err) 991 } 992 u.Path = "*" 993 994 req := &Request{ 995 Method: method, 996 Header: Header{}, 997 URL: u, 998 } 999 1000 res, err := cst.c.Do(req) 1001 if err != nil { 1002 t.Fatalf("RoundTrip = %v", err) 1003 } 1004 res.Body.Close() 1005 1006 wantFoo := "bar" 1007 wantLen := int64(-1) 1008 if method == "OPTIONS" { 1009 wantFoo = "" 1010 wantLen = 0 1011 } 1012 if res.StatusCode != 200 { 1013 t.Errorf("status code = %v; want %d", res.Status, 200) 1014 } 1015 if res.ContentLength != wantLen { 1016 t.Errorf("content length = %v; want %d", res.ContentLength, wantLen) 1017 } 1018 if got := res.Header.Get("foo"); got != wantFoo { 1019 t.Errorf("response \"foo\" header = %q; want %q", got, wantFoo) 1020 } 1021 select { 1022 case req = <-gotc: 1023 default: 1024 req = nil 1025 } 1026 if req == nil { 1027 if method != "OPTIONS" { 1028 t.Fatalf("handler never got request") 1029 } 1030 return 1031 } 1032 if req.Method != method { 1033 t.Errorf("method = %q; want %q", req.Method, method) 1034 } 1035 if req.URL.Path != "*" { 1036 t.Errorf("URL.Path = %q; want *", req.URL.Path) 1037 } 1038 if req.RequestURI != "*" { 1039 t.Errorf("RequestURI = %q; want *", req.RequestURI) 1040 } 1041 } 1042 1043 // Issue 13957 1044 func TestTransportDiscardsUnneededConns(t *testing.T) { 1045 run(t, 
testTransportDiscardsUnneededConns, []testMode{http2Mode}) 1046 } 1047 func testTransportDiscardsUnneededConns(t *testing.T, mode testMode) { 1048 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { 1049 fmt.Fprintf(w, "Hello, %v", r.RemoteAddr) 1050 })) 1051 defer cst.close() 1052 1053 var numOpen, numClose int32 // atomic 1054 1055 tlsConfig := &tls.Config{InsecureSkipVerify: true} 1056 tr := &Transport{ 1057 TLSClientConfig: tlsConfig, 1058 DialTLS: func(_, addr string) (net.Conn, error) { 1059 time.Sleep(10 * time.Millisecond) 1060 rc, err := net.Dial("tcp", addr) 1061 if err != nil { 1062 return nil, err 1063 } 1064 atomic.AddInt32(&numOpen, 1) 1065 c := noteCloseConn{rc, func() { atomic.AddInt32(&numClose, 1) }} 1066 return tls.Client(c, tlsConfig), nil 1067 }, 1068 } 1069 if err := ExportHttp2ConfigureTransport(tr); err != nil { 1070 t.Fatal(err) 1071 } 1072 defer tr.CloseIdleConnections() 1073 1074 c := &Client{Transport: tr} 1075 1076 const N = 10 1077 gotBody := make(chan string, N) 1078 var wg sync.WaitGroup 1079 for i := 0; i < N; i++ { 1080 wg.Add(1) 1081 go func() { 1082 defer wg.Done() 1083 resp, err := c.Get(cst.ts.URL) 1084 if err != nil { 1085 // Try to work around spurious connection reset on loaded system. 1086 // See golang.org/issue/33585 and golang.org/issue/36797. 
1087 time.Sleep(10 * time.Millisecond) 1088 resp, err = c.Get(cst.ts.URL) 1089 if err != nil { 1090 t.Errorf("Get: %v", err) 1091 return 1092 } 1093 } 1094 defer resp.Body.Close() 1095 slurp, err := io.ReadAll(resp.Body) 1096 if err != nil { 1097 t.Error(err) 1098 } 1099 gotBody <- string(slurp) 1100 }() 1101 } 1102 wg.Wait() 1103 close(gotBody) 1104 1105 var last string 1106 for got := range gotBody { 1107 if last == "" { 1108 last = got 1109 continue 1110 } 1111 if got != last { 1112 t.Errorf("Response body changed: %q -> %q", last, got) 1113 } 1114 } 1115 1116 var open, close int32 1117 for i := 0; i < 150; i++ { 1118 open, close = atomic.LoadInt32(&numOpen), atomic.LoadInt32(&numClose) 1119 if open < 1 { 1120 t.Fatalf("open = %d; want at least", open) 1121 } 1122 if close == open-1 { 1123 // Success 1124 return 1125 } 1126 time.Sleep(10 * time.Millisecond) 1127 } 1128 t.Errorf("%d connections opened, %d closed; want %d to close", open, close, open-1) 1129 } 1130 1131 // tests that Transport doesn't retain a pointer to the provided request. 
1132 func TestTransportGCRequest(t *testing.T) { 1133 run(t, func(t *testing.T, mode testMode) { 1134 t.Run("Body", func(t *testing.T) { testTransportGCRequest(t, mode, true) }) 1135 t.Run("NoBody", func(t *testing.T) { testTransportGCRequest(t, mode, false) }) 1136 }) 1137 } 1138 func testTransportGCRequest(t *testing.T, mode testMode, body bool) { 1139 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { 1140 io.ReadAll(r.Body) 1141 if body { 1142 io.WriteString(w, "Hello.") 1143 } 1144 })) 1145 1146 didGC := make(chan struct{}) 1147 (func() { 1148 body := strings.NewReader("some body") 1149 req, _ := NewRequest("POST", cst.ts.URL, body) 1150 runtime.SetFinalizer(req, func(*Request) { close(didGC) }) 1151 res, err := cst.c.Do(req) 1152 if err != nil { 1153 t.Fatal(err) 1154 } 1155 if _, err := io.ReadAll(res.Body); err != nil { 1156 t.Fatal(err) 1157 } 1158 if err := res.Body.Close(); err != nil { 1159 t.Fatal(err) 1160 } 1161 })() 1162 timeout := time.NewTimer(5 * time.Second) 1163 defer timeout.Stop() 1164 for { 1165 select { 1166 case <-didGC: 1167 return 1168 case <-time.After(100 * time.Millisecond): 1169 runtime.GC() 1170 case <-timeout.C: 1171 t.Fatal("never saw GC of request") 1172 } 1173 } 1174 } 1175 1176 func TestTransportRejectsInvalidHeaders(t *testing.T) { run(t, testTransportRejectsInvalidHeaders) } 1177 func testTransportRejectsInvalidHeaders(t *testing.T, mode testMode) { 1178 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { 1179 fmt.Fprintf(w, "Handler saw headers: %q", r.Header) 1180 }), optQuietLog) 1181 cst.tr.DisableKeepAlives = true 1182 1183 tests := []struct { 1184 key, val string 1185 ok bool 1186 }{ 1187 {"Foo", "capital-key", true}, // verify h2 allows capital keys 1188 {"Foo", "foo\x00bar", false}, // \x00 byte in value not allowed 1189 {"Foo", "two\nlines", false}, // \n byte in value not allowed 1190 {"bogus\nkey", "v", false}, // \n byte also not allowed in key 
1191 {"A space", "v", false}, // spaces in keys not allowed 1192 {"имя", "v", false}, // key must be ascii 1193 {"name", "валю", true}, // value may be non-ascii 1194 {"", "v", false}, // key must be non-empty 1195 {"k", "", true}, // value may be empty 1196 } 1197 for _, tt := range tests { 1198 dialedc := make(chan bool, 1) 1199 cst.tr.Dial = func(netw, addr string) (net.Conn, error) { 1200 dialedc <- true 1201 return net.Dial(netw, addr) 1202 } 1203 req, _ := NewRequest("GET", cst.ts.URL, nil) 1204 req.Header[tt.key] = []string{tt.val} 1205 res, err := cst.c.Do(req) 1206 var body []byte 1207 if err == nil { 1208 body, _ = io.ReadAll(res.Body) 1209 res.Body.Close() 1210 } 1211 var dialed bool 1212 select { 1213 case <-dialedc: 1214 dialed = true 1215 default: 1216 } 1217 1218 if !tt.ok && dialed { 1219 t.Errorf("For key %q, value %q, transport dialed. Expected local failure. Response was: (%v, %v)\nServer replied with: %s", tt.key, tt.val, res, err, body) 1220 } else if (err == nil) != tt.ok { 1221 t.Errorf("For key %q, value %q; got err = %v; want ok=%v", tt.key, tt.val, err, tt.ok) 1222 } 1223 } 1224 } 1225 1226 func TestInterruptWithPanic(t *testing.T) { 1227 run(t, func(t *testing.T, mode testMode) { 1228 t.Run("boom", func(t *testing.T) { testInterruptWithPanic(t, mode, "boom") }) 1229 t.Run("nil", func(t *testing.T) { testInterruptWithPanic(t, mode, nil) }) 1230 t.Run("ErrAbortHandler", func(t *testing.T) { testInterruptWithPanic(t, mode, ErrAbortHandler) }) 1231 }) 1232 } 1233 func testInterruptWithPanic(t *testing.T, mode testMode, panicValue any) { 1234 const msg = "hello" 1235 1236 testDone := make(chan struct{}) 1237 defer close(testDone) 1238 1239 var errorLog lockedBytesBuffer 1240 gotHeaders := make(chan bool, 1) 1241 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { 1242 io.WriteString(w, msg) 1243 w.(Flusher).Flush() 1244 1245 select { 1246 case <-gotHeaders: 1247 case <-testDone: 1248 } 1249 panic(panicValue) 
1250 }), func(ts *httptest.Server) { 1251 ts.Config.ErrorLog = log.New(&errorLog, "", 0) 1252 }) 1253 res, err := cst.c.Get(cst.ts.URL) 1254 if err != nil { 1255 t.Fatal(err) 1256 } 1257 gotHeaders <- true 1258 defer res.Body.Close() 1259 slurp, err := io.ReadAll(res.Body) 1260 if string(slurp) != msg { 1261 t.Errorf("client read %q; want %q", slurp, msg) 1262 } 1263 if err == nil { 1264 t.Errorf("client read all successfully; want some error") 1265 } 1266 logOutput := func() string { 1267 errorLog.Lock() 1268 defer errorLog.Unlock() 1269 return errorLog.String() 1270 } 1271 wantStackLogged := panicValue != nil && panicValue != ErrAbortHandler 1272 1273 if err := waitErrCondition(5*time.Second, 10*time.Millisecond, func() error { 1274 gotLog := logOutput() 1275 if !wantStackLogged { 1276 if gotLog == "" { 1277 return nil 1278 } 1279 return fmt.Errorf("want no log output; got: %s", gotLog) 1280 } 1281 if gotLog == "" { 1282 return fmt.Errorf("wanted a stack trace logged; got nothing") 1283 } 1284 if !strings.Contains(gotLog, "created by ") && strings.Count(gotLog, "\n") < 6 { 1285 return fmt.Errorf("output doesn't look like a panic stack trace. 
Got: %s", gotLog) 1286 } 1287 return nil 1288 }); err != nil { 1289 t.Fatal(err) 1290 } 1291 } 1292 1293 type lockedBytesBuffer struct { 1294 sync.Mutex 1295 bytes.Buffer 1296 } 1297 1298 func (b *lockedBytesBuffer) Write(p []byte) (int, error) { 1299 b.Lock() 1300 defer b.Unlock() 1301 return b.Buffer.Write(p) 1302 } 1303 1304 // Issue 15366 1305 func TestH12_AutoGzipWithDumpResponse(t *testing.T) { 1306 h12Compare{ 1307 Handler: func(w ResponseWriter, r *Request) { 1308 h := w.Header() 1309 h.Set("Content-Encoding", "gzip") 1310 h.Set("Content-Length", "23") 1311 io.WriteString(w, "\x1f\x8b\b\x00\x00\x00\x00\x00\x00\x00s\xf3\xf7\a\x00\xab'\xd4\x1a\x03\x00\x00\x00") 1312 }, 1313 EarlyCheckResponse: func(proto string, res *Response) { 1314 if !res.Uncompressed { 1315 t.Errorf("%s: expected Uncompressed to be set", proto) 1316 } 1317 dump, err := httputil.DumpResponse(res, true) 1318 if err != nil { 1319 t.Errorf("%s: DumpResponse: %v", proto, err) 1320 return 1321 } 1322 if strings.Contains(string(dump), "Connection: close") { 1323 t.Errorf("%s: should not see \"Connection: close\" in dump; got:\n%s", proto, dump) 1324 } 1325 if !strings.Contains(string(dump), "FOO") { 1326 t.Errorf("%s: should see \"FOO\" in response; got:\n%s", proto, dump) 1327 } 1328 }, 1329 }.run(t) 1330 } 1331 1332 // Issue 14607 1333 func TestCloseIdleConnections(t *testing.T) { run(t, testCloseIdleConnections) } 1334 func testCloseIdleConnections(t *testing.T, mode testMode) { 1335 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { 1336 w.Header().Set("X-Addr", r.RemoteAddr) 1337 })) 1338 get := func() string { 1339 res, err := cst.c.Get(cst.ts.URL) 1340 if err != nil { 1341 t.Fatal(err) 1342 } 1343 res.Body.Close() 1344 v := res.Header.Get("X-Addr") 1345 if v == "" { 1346 t.Fatal("didn't get X-Addr") 1347 } 1348 return v 1349 } 1350 a1 := get() 1351 cst.tr.CloseIdleConnections() 1352 a2 := get() 1353 if a1 == a2 { 1354 t.Errorf("didn't close connection") 
1355 } 1356 } 1357 1358 type noteCloseConn struct { 1359 net.Conn 1360 closeFunc func() 1361 } 1362 1363 func (x noteCloseConn) Close() error { 1364 x.closeFunc() 1365 return x.Conn.Close() 1366 } 1367 1368 type testErrorReader struct{ t *testing.T } 1369 1370 func (r testErrorReader) Read(p []byte) (n int, err error) { 1371 r.t.Error("unexpected Read call") 1372 return 0, io.EOF 1373 } 1374 1375 func TestNoSniffExpectRequestBody(t *testing.T) { run(t, testNoSniffExpectRequestBody) } 1376 func testNoSniffExpectRequestBody(t *testing.T, mode testMode) { 1377 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { 1378 w.WriteHeader(StatusUnauthorized) 1379 })) 1380 1381 // Set ExpectContinueTimeout non-zero so RoundTrip won't try to write it. 1382 cst.tr.ExpectContinueTimeout = 10 * time.Second 1383 1384 req, err := NewRequest("POST", cst.ts.URL, testErrorReader{t}) 1385 if err != nil { 1386 t.Fatal(err) 1387 } 1388 req.ContentLength = 0 // so transport is tempted to sniff it 1389 req.Header.Set("Expect", "100-continue") 1390 res, err := cst.tr.RoundTrip(req) 1391 if err != nil { 1392 t.Fatal(err) 1393 } 1394 defer res.Body.Close() 1395 if res.StatusCode != StatusUnauthorized { 1396 t.Errorf("status code = %v; want %v", res.StatusCode, StatusUnauthorized) 1397 } 1398 } 1399 1400 func TestServerUndeclaredTrailers(t *testing.T) { run(t, testServerUndeclaredTrailers) } 1401 func testServerUndeclaredTrailers(t *testing.T, mode testMode) { 1402 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { 1403 w.Header().Set("Foo", "Bar") 1404 w.Header().Set("Trailer:Foo", "Baz") 1405 w.(Flusher).Flush() 1406 w.Header().Add("Trailer:Foo", "Baz2") 1407 w.Header().Set("Trailer:Bar", "Quux") 1408 })) 1409 res, err := cst.c.Get(cst.ts.URL) 1410 if err != nil { 1411 t.Fatal(err) 1412 } 1413 if _, err := io.Copy(io.Discard, res.Body); err != nil { 1414 t.Fatal(err) 1415 } 1416 res.Body.Close() 1417 delete(res.Header, "Date") 
1418 delete(res.Header, "Content-Type") 1419 1420 if want := (Header{"Foo": {"Bar"}}); !reflect.DeepEqual(res.Header, want) { 1421 t.Errorf("Header = %#v; want %#v", res.Header, want) 1422 } 1423 if want := (Header{"Foo": {"Baz", "Baz2"}, "Bar": {"Quux"}}); !reflect.DeepEqual(res.Trailer, want) { 1424 t.Errorf("Trailer = %#v; want %#v", res.Trailer, want) 1425 } 1426 } 1427 1428 func TestBadResponseAfterReadingBody(t *testing.T) { 1429 run(t, testBadResponseAfterReadingBody, []testMode{http1Mode}) 1430 } 1431 func testBadResponseAfterReadingBody(t *testing.T, mode testMode) { 1432 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { 1433 _, err := io.Copy(io.Discard, r.Body) 1434 if err != nil { 1435 t.Fatal(err) 1436 } 1437 c, _, err := w.(Hijacker).Hijack() 1438 if err != nil { 1439 t.Fatal(err) 1440 } 1441 defer c.Close() 1442 fmt.Fprintln(c, "some bogus crap") 1443 })) 1444 1445 closes := 0 1446 res, err := cst.c.Post(cst.ts.URL, "text/plain", countCloseReader{&closes, strings.NewReader("hello")}) 1447 if err == nil { 1448 res.Body.Close() 1449 t.Fatal("expected an error to be returned from Post") 1450 } 1451 if closes != 1 { 1452 t.Errorf("closes = %d; want 1", closes) 1453 } 1454 } 1455 1456 func TestWriteHeader0(t *testing.T) { run(t, testWriteHeader0) } 1457 func testWriteHeader0(t *testing.T, mode testMode) { 1458 gotpanic := make(chan bool, 1) 1459 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { 1460 defer close(gotpanic) 1461 defer func() { 1462 if e := recover(); e != nil { 1463 got := fmt.Sprintf("%T, %v", e, e) 1464 want := "string, invalid WriteHeader code 0" 1465 if got != want { 1466 t.Errorf("unexpected panic value:\n got: %v\nwant: %v\n", got, want) 1467 } 1468 gotpanic <- true 1469 1470 // Set an explicit 503. This also tests that the WriteHeader call panics 1471 // before it recorded that an explicit value was set and that bogus 1472 // value wasn't stuck. 
1473 w.WriteHeader(503) 1474 } 1475 }() 1476 w.WriteHeader(0) 1477 })) 1478 res, err := cst.c.Get(cst.ts.URL) 1479 if err != nil { 1480 t.Fatal(err) 1481 } 1482 if res.StatusCode != 503 { 1483 t.Errorf("Response: %v %q; want 503", res.StatusCode, res.Status) 1484 } 1485 if !<-gotpanic { 1486 t.Error("expected panic in handler") 1487 } 1488 } 1489 1490 // Issue 23010: don't be super strict checking WriteHeader's code if 1491 // it's not even valid to call WriteHeader then anyway. 1492 func TestWriteHeaderNoCodeCheck(t *testing.T) { 1493 run(t, func(t *testing.T, mode testMode) { 1494 testWriteHeaderAfterWrite(t, mode, false) 1495 }) 1496 } 1497 func TestWriteHeaderNoCodeCheck_h1hijack(t *testing.T) { 1498 testWriteHeaderAfterWrite(t, http1Mode, true) 1499 } 1500 func testWriteHeaderAfterWrite(t *testing.T, mode testMode, hijack bool) { 1501 var errorLog lockedBytesBuffer 1502 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { 1503 if hijack { 1504 conn, _, _ := w.(Hijacker).Hijack() 1505 defer conn.Close() 1506 conn.Write([]byte("HTTP/1.1 200 OK\r\nContent-Length: 6\r\n\r\nfoo")) 1507 w.WriteHeader(0) // verify this doesn't panic if there's already output; Issue 23010 1508 conn.Write([]byte("bar")) 1509 return 1510 } 1511 io.WriteString(w, "foo") 1512 w.(Flusher).Flush() 1513 w.WriteHeader(0) // verify this doesn't panic if there's already output; Issue 23010 1514 io.WriteString(w, "bar") 1515 }), func(ts *httptest.Server) { 1516 ts.Config.ErrorLog = log.New(&errorLog, "", 0) 1517 }) 1518 res, err := cst.c.Get(cst.ts.URL) 1519 if err != nil { 1520 t.Fatal(err) 1521 } 1522 defer res.Body.Close() 1523 body, err := io.ReadAll(res.Body) 1524 if err != nil { 1525 t.Fatal(err) 1526 } 1527 if got, want := string(body), "foobar"; got != want { 1528 t.Errorf("got = %q; want %q", got, want) 1529 } 1530 1531 // Also check the stderr output: 1532 if mode == http2Mode { 1533 // TODO: also emit this log message for HTTP/2? 
1534 // We historically haven't, so don't check. 1535 return 1536 } 1537 gotLog := strings.TrimSpace(errorLog.String()) 1538 wantLog := "http: superfluous response.WriteHeader call from net/http_test.testWriteHeaderAfterWrite.func1 (clientserver_test.go:" 1539 if hijack { 1540 wantLog = "http: response.WriteHeader on hijacked connection from net/http_test.testWriteHeaderAfterWrite.func1 (clientserver_test.go:" 1541 } 1542 if !strings.HasPrefix(gotLog, wantLog) { 1543 t.Errorf("stderr output = %q; want %q", gotLog, wantLog) 1544 } 1545 } 1546 1547 func TestBidiStreamReverseProxy(t *testing.T) { 1548 run(t, testBidiStreamReverseProxy, []testMode{http2Mode}) 1549 } 1550 func testBidiStreamReverseProxy(t *testing.T, mode testMode) { 1551 backend := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { 1552 if _, err := io.Copy(w, r.Body); err != nil { 1553 log.Printf("bidi backend copy: %v", err) 1554 } 1555 })) 1556 1557 backURL, err := url.Parse(backend.ts.URL) 1558 if err != nil { 1559 t.Fatal(err) 1560 } 1561 rp := httputil.NewSingleHostReverseProxy(backURL) 1562 rp.Transport = backend.tr 1563 proxy := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { 1564 rp.ServeHTTP(w, r) 1565 })) 1566 1567 bodyRes := make(chan any, 1) // error or hash.Hash 1568 pr, pw := io.Pipe() 1569 req, _ := NewRequest("PUT", proxy.ts.URL, pr) 1570 const size = 4 << 20 1571 go func() { 1572 h := sha1.New() 1573 _, err := io.CopyN(io.MultiWriter(h, pw), rand.Reader, size) 1574 go pw.Close() 1575 if err != nil { 1576 bodyRes <- err 1577 } else { 1578 bodyRes <- h 1579 } 1580 }() 1581 res, err := backend.c.Do(req) 1582 if err != nil { 1583 t.Fatal(err) 1584 } 1585 defer res.Body.Close() 1586 hgot := sha1.New() 1587 n, err := io.Copy(hgot, res.Body) 1588 if err != nil { 1589 t.Fatal(err) 1590 } 1591 if n != size { 1592 t.Fatalf("got %d bytes; want %d", n, size) 1593 } 1594 select { 1595 case v := <-bodyRes: 1596 switch v := v.(type) { 1597 
default: 1598 t.Fatalf("body copy: %v", err) 1599 case hash.Hash: 1600 if !bytes.Equal(v.Sum(nil), hgot.Sum(nil)) { 1601 t.Errorf("written bytes didn't match received bytes") 1602 } 1603 } 1604 case <-time.After(10 * time.Second): 1605 t.Fatal("timeout") 1606 } 1607 1608 } 1609 1610 // Always use HTTP/1.1 for WebSocket upgrades. 1611 func TestH12_WebSocketUpgrade(t *testing.T) { 1612 h12Compare{ 1613 Handler: func(w ResponseWriter, r *Request) { 1614 h := w.Header() 1615 h.Set("Foo", "bar") 1616 }, 1617 ReqFunc: func(c *Client, url string) (*Response, error) { 1618 req, _ := NewRequest("GET", url, nil) 1619 req.Header.Set("Connection", "Upgrade") 1620 req.Header.Set("Upgrade", "WebSocket") 1621 return c.Do(req) 1622 }, 1623 EarlyCheckResponse: func(proto string, res *Response) { 1624 if res.Proto != "HTTP/1.1" { 1625 t.Errorf("%s: expected HTTP/1.1, got %q", proto, res.Proto) 1626 } 1627 res.Proto = "HTTP/IGNORE" // skip later checks that Proto must be 1.1 vs 2.0 1628 }, 1629 }.run(t) 1630 } 1631 1632 func TestIdentityTransferEncoding(t *testing.T) { run(t, testIdentityTransferEncoding) } 1633 func testIdentityTransferEncoding(t *testing.T, mode testMode) { 1634 const body = "body" 1635 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { 1636 gotBody, _ := io.ReadAll(r.Body) 1637 if got, want := string(gotBody), body; got != want { 1638 t.Errorf("got request body = %q; want %q", got, want) 1639 } 1640 w.Header().Set("Transfer-Encoding", "identity") 1641 w.WriteHeader(StatusOK) 1642 w.(Flusher).Flush() 1643 io.WriteString(w, body) 1644 })) 1645 req, _ := NewRequest("GET", cst.ts.URL, strings.NewReader(body)) 1646 res, err := cst.c.Do(req) 1647 if err != nil { 1648 t.Fatal(err) 1649 } 1650 defer res.Body.Close() 1651 gotBody, err := io.ReadAll(res.Body) 1652 if err != nil { 1653 t.Fatal(err) 1654 } 1655 if got, want := string(gotBody), body; got != want { 1656 t.Errorf("got response body = %q; want %q", got, want) 1657 } 1658 } 1659 
1660 func TestEarlyHintsRequest(t *testing.T) { run(t, testEarlyHintsRequest) } 1661 func testEarlyHintsRequest(t *testing.T, mode testMode) { 1662 var wg sync.WaitGroup 1663 wg.Add(1) 1664 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { 1665 h := w.Header() 1666 1667 h.Add("Content-Length", "123") // must be ignored 1668 h.Add("Link", "</style.css>; rel=preload; as=style") 1669 h.Add("Link", "</script.js>; rel=preload; as=script") 1670 w.WriteHeader(StatusEarlyHints) 1671 1672 wg.Wait() 1673 1674 h.Add("Link", "</foo.js>; rel=preload; as=script") 1675 w.WriteHeader(StatusEarlyHints) 1676 1677 w.Write([]byte("Hello")) 1678 })) 1679 1680 checkLinkHeaders := func(t *testing.T, expected, got []string) { 1681 t.Helper() 1682 1683 if len(expected) != len(got) { 1684 t.Errorf("got %d expected %d", len(got), len(expected)) 1685 } 1686 1687 for i := range expected { 1688 if expected[i] != got[i] { 1689 t.Errorf("got %q expected %q", got[i], expected[i]) 1690 } 1691 } 1692 } 1693 1694 checkExcludedHeaders := func(t *testing.T, header textproto.MIMEHeader) { 1695 t.Helper() 1696 1697 for _, h := range []string{"Content-Length", "Transfer-Encoding"} { 1698 if v, ok := header[h]; ok { 1699 t.Errorf("%s is %q; must not be sent", h, v) 1700 } 1701 } 1702 } 1703 1704 var respCounter uint8 1705 trace := &httptrace.ClientTrace{ 1706 Got1xxResponse: func(code int, header textproto.MIMEHeader) error { 1707 switch respCounter { 1708 case 0: 1709 checkLinkHeaders(t, []string{"</style.css>; rel=preload; as=style", "</script.js>; rel=preload; as=script"}, header["Link"]) 1710 checkExcludedHeaders(t, header) 1711 1712 wg.Done() 1713 case 1: 1714 checkLinkHeaders(t, []string{"</style.css>; rel=preload; as=style", "</script.js>; rel=preload; as=script", "</foo.js>; rel=preload; as=script"}, header["Link"]) 1715 checkExcludedHeaders(t, header) 1716 1717 default: 1718 t.Error("Unexpected 1xx response") 1719 } 1720 1721 respCounter++ 1722 1723 return nil 
1724 }, 1725 } 1726 req, _ := NewRequestWithContext(httptrace.WithClientTrace(context.Background(), trace), "GET", cst.ts.URL, nil) 1727 1728 res, err := cst.c.Do(req) 1729 if err != nil { 1730 t.Fatal(err) 1731 } 1732 defer res.Body.Close() 1733 1734 checkLinkHeaders(t, []string{"</style.css>; rel=preload; as=style", "</script.js>; rel=preload; as=script", "</foo.js>; rel=preload; as=script"}, res.Header["Link"]) 1735 if cl := res.Header.Get("Content-Length"); cl != "123" { 1736 t.Errorf("Content-Length is %q; want 123", cl) 1737 } 1738 1739 body, _ := io.ReadAll(res.Body) 1740 if string(body) != "Hello" { 1741 t.Errorf("Read body %q; want Hello", body) 1742 } 1743 }