github.com/ooni/oohttp@v0.7.2/clientserver_test.go (about) 1 // Copyright 2015 The Go Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 // Tests that use both the client & server, in both HTTP/1 and HTTP/2 mode. 6 7 package http_test 8 9 import ( 10 "bytes" 11 "compress/gzip" 12 "context" 13 "crypto/rand" 14 "crypto/sha1" 15 "crypto/tls" 16 "fmt" 17 "hash" 18 "io" 19 "log" 20 "net" 21 "net/textproto" 22 "net/url" 23 "os" 24 "reflect" 25 "runtime" 26 "sort" 27 "strings" 28 "sync" 29 "sync/atomic" 30 "testing" 31 "time" 32 33 . "github.com/ooni/oohttp" 34 httptest "github.com/ooni/oohttp/httptest" 35 httptrace "github.com/ooni/oohttp/httptrace" 36 httputil "github.com/ooni/oohttp/httputil" 37 ) 38 39 type testMode string 40 41 const ( 42 http1Mode = testMode("h1") // HTTP/1.1 43 https1Mode = testMode("https1") // HTTPS/1.1 44 http2Mode = testMode("h2") // HTTP/2 45 ) 46 47 type testNotParallelOpt struct{} 48 49 var ( 50 testNotParallel = testNotParallelOpt{} 51 ) 52 53 type TBRun[T any] interface { 54 testing.TB 55 Run(string, func(T)) bool 56 } 57 58 // run runs a client/server test in a variety of test configurations. 59 // 60 // Tests execute in HTTP/1.1 and HTTP/2 modes by default. 61 // To run in a different set of configurations, pass a []testMode option. 62 // 63 // Tests call t.Parallel() by default. 64 // To disable parallel execution, pass the testNotParallel option. 65 func run[T TBRun[T]](t T, f func(t T, mode testMode), opts ...any) { 66 t.Helper() 67 modes := []testMode{http1Mode, http2Mode} 68 parallel := true 69 for _, opt := range opts { 70 switch opt := opt.(type) { 71 case []testMode: 72 modes = opt 73 case testNotParallelOpt: 74 parallel = false 75 default: 76 t.Fatalf("unknown option type %T", opt) 77 } 78 } 79 if t, ok := any(t).(*testing.T); ok && parallel { 80 setParallel(t) 81 } 82 for _, mode := range modes { 83 t.Run(string(mode), func(t T) { 84 t.Helper() 85 if t, ok := any(t).(*testing.T); ok && parallel { 86 setParallel(t) 87 } 88 t.Cleanup(func() { 89 afterTest(t) 90 }) 91 f(t, mode) 92 }) 93 } 94 } 95 96 type clientServerTest struct { 97 t testing.TB 98 h2 bool 99 h Handler 100 ts *httptest.Server 101 tr *Transport 102 c *Client 103 } 104 105 func (t *clientServerTest) close() { 106 t.tr.CloseIdleConnections() 107 t.ts.Close() 108 } 109 110 func (t *clientServerTest) getURL(u string) string { 111 res, err := t.c.Get(u) 112 if err != nil { 113 t.t.Fatal(err) 114 } 115 defer res.Body.Close() 116 slurp, err := io.ReadAll(res.Body) 117 if err != nil { 118 t.t.Fatal(err) 119 } 120 return string(slurp) 121 } 122 123 func (t *clientServerTest) scheme() string { 124 if t.h2 { 125 return "https" 126 } 127 return "http" 128 } 129 130 var optQuietLog = func(ts *httptest.Server) { 131 ts.Config.ErrorLog = quietLog 132 } 133 134 func optWithServerLog(lg *log.Logger) func(*httptest.Server) { 135 return func(ts *httptest.Server) { 136 ts.Config.ErrorLog = lg 137 } 138 } 139 140 // newClientServerTest creates and starts an httptest.Server. 141 // 142 // The mode parameter selects the implementation to test: 143 // HTTP/1, HTTP/2, etc. Tests using newClientServerTest should use 144 // the 'run' function, which will start a subtests for each tested mode. 145 // 146 // The vararg opts parameter can include functions to configure the 147 // test server or transport. 148 // 149 // func(*httptest.Server) // run before starting the server 150 // func(*http.Transport) 151 func newClientServerTest(t testing.TB, mode testMode, h Handler, opts ...any) *clientServerTest { 152 if mode == http2Mode { 153 CondSkipHTTP2(t) 154 } 155 cst := &clientServerTest{ 156 t: t, 157 h2: mode == http2Mode, 158 h: h, 159 } 160 cst.ts = httptest.NewUnstartedServer(h) 161 162 var transportFuncs []func(*Transport) 163 for _, opt := range opts { 164 switch opt := opt.(type) { 165 case func(*Transport): 166 transportFuncs = append(transportFuncs, opt) 167 case func(*httptest.Server): 168 opt(cst.ts) 169 default: 170 t.Fatalf("unhandled option type %T", opt) 171 } 172 } 173 174 if cst.ts.Config.ErrorLog == nil { 175 cst.ts.Config.ErrorLog = log.New(testLogWriter{t}, "", 0) 176 } 177 178 switch mode { 179 case http1Mode: 180 cst.ts.Start() 181 case https1Mode: 182 cst.ts.StartTLS() 183 case http2Mode: 184 ExportHttp2ConfigureServer(cst.ts.Config, nil) 185 cst.ts.TLS = cst.ts.Config.TLSConfig 186 cst.ts.StartTLS() 187 default: 188 t.Fatalf("unknown test mode %v", mode) 189 } 190 cst.c = cst.ts.Client() 191 cst.tr = cst.c.Transport.(*Transport) 192 if mode == http2Mode { 193 if err := ExportHttp2ConfigureTransport(cst.tr); err != nil { 194 t.Fatal(err) 195 } 196 } 197 for _, f := range transportFuncs { 198 f(cst.tr) 199 } 200 t.Cleanup(func() { 201 cst.close() 202 }) 203 return cst 204 } 205 206 type testLogWriter struct { 207 t testing.TB 208 } 209 210 func (w testLogWriter) Write(b []byte) (int, error) { 211 w.t.Logf("server log: %v", strings.TrimSpace(string(b))) 212 return len(b), nil 213 } 214 215 // Testing the newClientServerTest helper itself. 216 func TestNewClientServerTest(t *testing.T) { 217 run(t, testNewClientServerTest, []testMode{http1Mode, https1Mode, http2Mode}) 218 } 219 func testNewClientServerTest(t *testing.T, mode testMode) { 220 var got struct { 221 sync.Mutex 222 proto string 223 hasTLS bool 224 } 225 h := HandlerFunc(func(w ResponseWriter, r *Request) { 226 got.Lock() 227 defer got.Unlock() 228 got.proto = r.Proto 229 got.hasTLS = r.TLS != nil 230 }) 231 cst := newClientServerTest(t, mode, h) 232 if _, err := cst.c.Head(cst.ts.URL); err != nil { 233 t.Fatal(err) 234 } 235 var wantProto string 236 var wantTLS bool 237 switch mode { 238 case http1Mode: 239 wantProto = "HTTP/1.1" 240 wantTLS = false 241 case https1Mode: 242 wantProto = "HTTP/1.1" 243 wantTLS = true 244 case http2Mode: 245 wantProto = "HTTP/2.0" 246 wantTLS = true 247 } 248 if got.proto != wantProto { 249 t.Errorf("req.Proto = %q, want %q", got.proto, wantProto) 250 } 251 if got.hasTLS != wantTLS { 252 t.Errorf("req.TLS set: %v, want %v", got.hasTLS, wantTLS) 253 } 254 } 255 256 func TestChunkedResponseHeaders(t *testing.T) { run(t, testChunkedResponseHeaders) } 257 func testChunkedResponseHeaders(t *testing.T, mode testMode) { 258 log.SetOutput(io.Discard) // is noisy otherwise 259 defer log.SetOutput(os.Stderr) 260 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { 261 w.Header().Set("Content-Length", "intentional gibberish") // we check that this is deleted 262 w.(Flusher).Flush() 263 fmt.Fprintf(w, "I am a chunked response.") 264 })) 265 266 res, err := cst.c.Get(cst.ts.URL) 267 if err != nil { 268 t.Fatalf("Get error: %v", err) 269 } 270 defer res.Body.Close() 271 if g, e := res.ContentLength, int64(-1); g != e { 272 t.Errorf("expected ContentLength of %d; got %d", e, g) 273 } 274 wantTE := []string{"chunked"} 275 if mode == http2Mode { 276 wantTE = nil 277 } 278 if !reflect.DeepEqual(res.TransferEncoding, wantTE) { 279 t.Errorf("TransferEncoding = %v; want %v", res.TransferEncoding, wantTE) 280 } 281 if got, haveCL := res.Header["Content-Length"]; haveCL { 282 t.Errorf("Unexpected Content-Length: %q", got) 283 } 284 } 285 286 type reqFunc func(c *Client, url string) (*Response, error) 287 288 // h12Compare is a test that compares HTTP/1 and HTTP/2 behavior 289 // against each other. 290 type h12Compare struct { 291 Handler func(ResponseWriter, *Request) // required 292 ReqFunc reqFunc // optional 293 CheckResponse func(proto string, res *Response) // optional 294 EarlyCheckResponse func(proto string, res *Response) // optional; pre-normalize 295 Opts []any 296 } 297 298 func (tt h12Compare) reqFunc() reqFunc { 299 if tt.ReqFunc == nil { 300 return (*Client).Get 301 } 302 return tt.ReqFunc 303 } 304 305 func (tt h12Compare) run(t *testing.T) { 306 setParallel(t) 307 cst1 := newClientServerTest(t, http1Mode, HandlerFunc(tt.Handler), tt.Opts...) 308 defer cst1.close() 309 cst2 := newClientServerTest(t, http2Mode, HandlerFunc(tt.Handler), tt.Opts...) 310 defer cst2.close() 311 312 res1, err := tt.reqFunc()(cst1.c, cst1.ts.URL) 313 if err != nil { 314 t.Errorf("HTTP/1 request: %v", err) 315 return 316 } 317 res2, err := tt.reqFunc()(cst2.c, cst2.ts.URL) 318 if err != nil { 319 t.Errorf("HTTP/2 request: %v", err) 320 return 321 } 322 323 if fn := tt.EarlyCheckResponse; fn != nil { 324 fn("HTTP/1.1", res1) 325 fn("HTTP/2.0", res2) 326 } 327 328 tt.normalizeRes(t, res1, "HTTP/1.1") 329 tt.normalizeRes(t, res2, "HTTP/2.0") 330 res1body, res2body := res1.Body, res2.Body 331 332 eres1 := mostlyCopy(res1) 333 eres2 := mostlyCopy(res2) 334 if !reflect.DeepEqual(eres1, eres2) { 335 t.Errorf("Response headers to handler differed:\nhttp/1 (%v):\n\t%#v\nhttp/2 (%v):\n\t%#v", 336 cst1.ts.URL, eres1, cst2.ts.URL, eres2) 337 } 338 if !reflect.DeepEqual(res1body, res2body) { 339 t.Errorf("Response bodies to handler differed.\nhttp1: %v\nhttp2: %v\n", res1body, res2body) 340 } 341 if fn := tt.CheckResponse; fn != nil { 342 res1.Body, res2.Body = res1body, res2body 343 fn("HTTP/1.1", res1) 344 fn("HTTP/2.0", res2) 345 } 346 } 347 348 func mostlyCopy(r *Response) *Response { 349 c := *r 350 c.Body = nil 351 c.TransferEncoding = nil 352 c.TLS = nil 353 c.Request = nil 354 return &c 355 } 356 357 type slurpResult struct { 358 io.ReadCloser 359 body []byte 360 err error 361 } 362 363 func (sr slurpResult) String() string { return fmt.Sprintf("body %q; err %v", sr.body, sr.err) } 364 365 func (tt h12Compare) normalizeRes(t *testing.T, res *Response, wantProto string) { 366 if res.Proto == wantProto || res.Proto == "HTTP/IGNORE" { 367 res.Proto, res.ProtoMajor, res.ProtoMinor = "", 0, 0 368 } else { 369 t.Errorf("got %q response; want %q", res.Proto, wantProto) 370 } 371 slurp, err := io.ReadAll(res.Body) 372 373 res.Body.Close() 374 res.Body = slurpResult{ 375 ReadCloser: io.NopCloser(bytes.NewReader(slurp)), 376 body: slurp, 377 err: err, 378 } 379 for i, v := range res.Header["Date"] { 380 res.Header["Date"][i] = strings.Repeat("x", len(v)) 381 } 382 if res.Request == nil { 383 t.Errorf("for %s, no request", wantProto) 384 } 385 if (res.TLS != nil) != (wantProto == "HTTP/2.0") { 386 t.Errorf("TLS set = %v; want %v", res.TLS != nil, res.TLS == nil) 387 } 388 } 389 390 // Issue 13532 391 func TestH12_HeadContentLengthNoBody(t *testing.T) { 392 h12Compare{ 393 ReqFunc: (*Client).Head, 394 Handler: func(w ResponseWriter, r *Request) { 395 }, 396 }.run(t) 397 } 398 399 func TestH12_HeadContentLengthSmallBody(t *testing.T) { 400 h12Compare{ 401 ReqFunc: (*Client).Head, 402 Handler: func(w ResponseWriter, r *Request) { 403 io.WriteString(w, "small") 404 }, 405 }.run(t) 406 } 407 408 func TestH12_HeadContentLengthLargeBody(t *testing.T) { 409 h12Compare{ 410 ReqFunc: (*Client).Head, 411 Handler: func(w ResponseWriter, r *Request) { 412 chunk := strings.Repeat("x", 512<<10) 413 for i := 0; i < 10; i++ { 414 io.WriteString(w, chunk) 415 } 416 }, 417 }.run(t) 418 } 419 420 func TestH12_200NoBody(t *testing.T) { 421 h12Compare{Handler: func(w ResponseWriter, r *Request) {}}.run(t) 422 } 423 424 func TestH2_204NoBody(t *testing.T) { testH12_noBody(t, 204) } 425 func TestH2_304NoBody(t *testing.T) { testH12_noBody(t, 304) } 426 func TestH2_404NoBody(t *testing.T) { testH12_noBody(t, 404) } 427 428 func testH12_noBody(t *testing.T, status int) { 429 h12Compare{Handler: func(w ResponseWriter, r *Request) { 430 w.WriteHeader(status) 431 }}.run(t) 432 } 433 434 func TestH12_SmallBody(t *testing.T) { 435 h12Compare{Handler: func(w ResponseWriter, r *Request) { 436 io.WriteString(w, "small body") 437 }}.run(t) 438 } 439 440 func TestH12_ExplicitContentLength(t *testing.T) { 441 h12Compare{Handler: func(w ResponseWriter, r *Request) { 442 w.Header().Set("Content-Length", "3") 443 io.WriteString(w, "foo") 444 }}.run(t) 445 } 446 447 func TestH12_FlushBeforeBody(t *testing.T) { 448 h12Compare{Handler: func(w ResponseWriter, r *Request) { 449 w.(Flusher).Flush() 450 io.WriteString(w, "foo") 451 }}.run(t) 452 } 453 454 func TestH12_FlushMidBody(t *testing.T) { 455 h12Compare{Handler: func(w ResponseWriter, r *Request) { 456 io.WriteString(w, "foo") 457 w.(Flusher).Flush() 458 io.WriteString(w, "bar") 459 }}.run(t) 460 } 461 462 func TestH12_Head_ExplicitLen(t *testing.T) { 463 h12Compare{ 464 ReqFunc: (*Client).Head, 465 Handler: func(w ResponseWriter, r *Request) { 466 if r.Method != "HEAD" { 467 t.Errorf("unexpected method %q", r.Method) 468 } 469 w.Header().Set("Content-Length", "1235") 470 }, 471 }.run(t) 472 } 473 474 func TestH12_Head_ImplicitLen(t *testing.T) { 475 h12Compare{ 476 ReqFunc: (*Client).Head, 477 Handler: func(w ResponseWriter, r *Request) { 478 if r.Method != "HEAD" { 479 t.Errorf("unexpected method %q", r.Method) 480 } 481 io.WriteString(w, "foo") 482 }, 483 }.run(t) 484 } 485 486 func TestH12_HandlerWritesTooLittle(t *testing.T) { 487 h12Compare{ 488 Handler: func(w ResponseWriter, r *Request) { 489 w.Header().Set("Content-Length", "3") 490 io.WriteString(w, "12") // one byte short 491 }, 492 CheckResponse: func(proto string, res *Response) { 493 sr, ok := res.Body.(slurpResult) 494 if !ok { 495 t.Errorf("%s body is %T; want slurpResult", proto, res.Body) 496 return 497 } 498 if sr.err != io.ErrUnexpectedEOF { 499 t.Errorf("%s read error = %v; want io.ErrUnexpectedEOF", proto, sr.err) 500 } 501 if string(sr.body) != "12" { 502 t.Errorf("%s body = %q; want %q", proto, sr.body, "12") 503 } 504 }, 505 }.run(t) 506 } 507 508 // Tests that the HTTP/1 and HTTP/2 servers prevent handlers from 509 // writing more than they declared. This test does not test whether 510 // the transport deals with too much data, though, since the server 511 // doesn't make it possible to send bogus data. For those tests, see 512 // transport_test.go (for HTTP/1) or x/net/http2/transport_test.go 513 // (for HTTP/2). 514 func TestH12_HandlerWritesTooMuch(t *testing.T) { 515 h12Compare{ 516 Handler: func(w ResponseWriter, r *Request) { 517 w.Header().Set("Content-Length", "3") 518 w.(Flusher).Flush() 519 io.WriteString(w, "123") 520 w.(Flusher).Flush() 521 n, err := io.WriteString(w, "x") // too many 522 if n > 0 || err == nil { 523 t.Errorf("for proto %q, final write = %v, %v; want 0, some error", r.Proto, n, err) 524 } 525 }, 526 }.run(t) 527 } 528 529 // Verify that both our HTTP/1 and HTTP/2 request and auto-decompress gzip. 530 // Some hosts send gzip even if you don't ask for it; see golang.org/issue/13298 531 func TestH12_AutoGzip(t *testing.T) { 532 h12Compare{ 533 Handler: func(w ResponseWriter, r *Request) { 534 if ae := r.Header.Get("Accept-Encoding"); ae != "gzip" { 535 t.Errorf("%s Accept-Encoding = %q; want gzip", r.Proto, ae) 536 } 537 w.Header().Set("Content-Encoding", "gzip") 538 gz := gzip.NewWriter(w) 539 io.WriteString(gz, "I am some gzipped content. Go go go go go go go go go go go go should compress well.") 540 gz.Close() 541 }, 542 }.run(t) 543 } 544 545 func TestH12_AutoGzip_Disabled(t *testing.T) { 546 h12Compare{ 547 Opts: []any{ 548 func(tr *Transport) { tr.DisableCompression = true }, 549 }, 550 Handler: func(w ResponseWriter, r *Request) { 551 fmt.Fprintf(w, "%q", r.Header["Accept-Encoding"]) 552 if ae := r.Header.Get("Accept-Encoding"); ae != "" { 553 t.Errorf("%s Accept-Encoding = %q; want empty", r.Proto, ae) 554 } 555 }, 556 }.run(t) 557 } 558 559 // Test304Responses verifies that 304s don't declare that they're 560 // chunking in their response headers and aren't allowed to produce 561 // output. 562 func Test304Responses(t *testing.T) { run(t, test304Responses) } 563 func test304Responses(t *testing.T, mode testMode) { 564 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { 565 w.WriteHeader(StatusNotModified) 566 _, err := w.Write([]byte("illegal body")) 567 if err != ErrBodyNotAllowed { 568 t.Errorf("on Write, expected ErrBodyNotAllowed, got %v", err) 569 } 570 })) 571 defer cst.close() 572 res, err := cst.c.Get(cst.ts.URL) 573 if err != nil { 574 t.Fatal(err) 575 } 576 if len(res.TransferEncoding) > 0 { 577 t.Errorf("expected no TransferEncoding; got %v", res.TransferEncoding) 578 } 579 body, err := io.ReadAll(res.Body) 580 if err != nil { 581 t.Error(err) 582 } 583 if len(body) > 0 { 584 t.Errorf("got unexpected body %q", string(body)) 585 } 586 } 587 588 func TestH12_ServerEmptyContentLength(t *testing.T) { 589 h12Compare{ 590 Handler: func(w ResponseWriter, r *Request) { 591 w.Header()["Content-Type"] = []string{""} 592 io.WriteString(w, "<html><body>hi</body></html>") 593 }, 594 }.run(t) 595 } 596 597 func TestH12_RequestContentLength_Known_NonZero(t *testing.T) { 598 h12requestContentLength(t, func() io.Reader { return strings.NewReader("FOUR") }, 4) 599 } 600 601 func TestH12_RequestContentLength_Known_Zero(t *testing.T) { 602 h12requestContentLength(t, func() io.Reader { return nil }, 0) 603 } 604 605 func TestH12_RequestContentLength_Unknown(t *testing.T) { 606 h12requestContentLength(t, func() io.Reader { return struct{ io.Reader }{strings.NewReader("Stuff")} }, -1) 607 } 608 609 func h12requestContentLength(t *testing.T, bodyfn func() io.Reader, wantLen int64) { 610 h12Compare{ 611 Handler: func(w ResponseWriter, r *Request) { 612 w.Header().Set("Got-Length", fmt.Sprint(r.ContentLength)) 613 fmt.Fprintf(w, "Req.ContentLength=%v", r.ContentLength) 614 }, 615 ReqFunc: func(c *Client, url string) (*Response, error) { 616 return c.Post(url, "text/plain", bodyfn()) 617 }, 618 CheckResponse: func(proto string, res *Response) { 619 if got, want := res.Header.Get("Got-Length"), fmt.Sprint(wantLen); got != want { 620 t.Errorf("Proto %q got length %q; want %q", proto, got, want) 621 } 622 }, 623 }.run(t) 624 } 625 626 // Tests that closing the Request.Cancel channel also while still 627 // reading the response body. Issue 13159. 628 func TestCancelRequestMidBody(t *testing.T) { run(t, testCancelRequestMidBody) } 629 func testCancelRequestMidBody(t *testing.T, mode testMode) { 630 unblock := make(chan bool) 631 didFlush := make(chan bool, 1) 632 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { 633 io.WriteString(w, "Hello") 634 w.(Flusher).Flush() 635 didFlush <- true 636 <-unblock 637 io.WriteString(w, ", world.") 638 })) 639 defer close(unblock) 640 641 req, _ := NewRequest("GET", cst.ts.URL, nil) 642 cancel := make(chan struct{}) 643 req.Cancel = cancel 644 645 res, err := cst.c.Do(req) 646 if err != nil { 647 t.Fatal(err) 648 } 649 defer res.Body.Close() 650 <-didFlush 651 652 // Read a bit before we cancel. (Issue 13626) 653 // We should have "Hello" at least sitting there. 654 firstRead := make([]byte, 10) 655 n, err := res.Body.Read(firstRead) 656 if err != nil { 657 t.Fatal(err) 658 } 659 firstRead = firstRead[:n] 660 661 close(cancel) 662 663 rest, err := io.ReadAll(res.Body) 664 all := string(firstRead) + string(rest) 665 if all != "Hello" { 666 t.Errorf("Read %q (%q + %q); want Hello", all, firstRead, rest) 667 } 668 if err != ExportErrRequestCanceled { 669 t.Errorf("ReadAll error = %v; want %v", err, ExportErrRequestCanceled) 670 } 671 } 672 673 // Tests that clients can send trailers to a server and that the server can read them. 674 func TestTrailersClientToServer(t *testing.T) { run(t, testTrailersClientToServer) } 675 func testTrailersClientToServer(t *testing.T, mode testMode) { 676 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { 677 var decl []string 678 for k := range r.Trailer { 679 decl = append(decl, k) 680 } 681 sort.Strings(decl) 682 683 slurp, err := io.ReadAll(r.Body) 684 if err != nil { 685 t.Errorf("Server reading request body: %v", err) 686 } 687 if string(slurp) != "foo" { 688 t.Errorf("Server read request body %q; want foo", slurp) 689 } 690 if r.Trailer == nil { 691 io.WriteString(w, "nil Trailer") 692 } else { 693 fmt.Fprintf(w, "decl: %v, vals: %s, %s", 694 decl, 695 r.Trailer.Get("Client-Trailer-A"), 696 r.Trailer.Get("Client-Trailer-B")) 697 } 698 })) 699 700 var req *Request 701 req, _ = NewRequest("POST", cst.ts.URL, io.MultiReader( 702 eofReaderFunc(func() { 703 req.Trailer["Client-Trailer-A"] = []string{"valuea"} 704 }), 705 strings.NewReader("foo"), 706 eofReaderFunc(func() { 707 req.Trailer["Client-Trailer-B"] = []string{"valueb"} 708 }), 709 )) 710 req.Trailer = Header{ 711 "Client-Trailer-A": nil, // to be set later 712 "Client-Trailer-B": nil, // to be set later 713 } 714 req.ContentLength = -1 715 res, err := cst.c.Do(req) 716 if err != nil { 717 t.Fatal(err) 718 } 719 if err := wantBody(res, err, "decl: [Client-Trailer-A Client-Trailer-B], vals: valuea, valueb"); err != nil { 720 t.Error(err) 721 } 722 } 723 724 // Tests that servers send trailers to a client and that the client can read them. 725 func TestTrailersServerToClient(t *testing.T) { 726 run(t, func(t *testing.T, mode testMode) { 727 testTrailersServerToClient(t, mode, false) 728 }) 729 } 730 func TestTrailersServerToClientFlush(t *testing.T) { 731 run(t, func(t *testing.T, mode testMode) { 732 testTrailersServerToClient(t, mode, true) 733 }) 734 } 735 736 func testTrailersServerToClient(t *testing.T, mode testMode, flush bool) { 737 const body = "Some body" 738 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { 739 w.Header().Set("Trailer", "Server-Trailer-A, Server-Trailer-B") 740 w.Header().Add("Trailer", "Server-Trailer-C") 741 742 io.WriteString(w, body) 743 if flush { 744 w.(Flusher).Flush() 745 } 746 747 // How handlers set Trailers: declare it ahead of time 748 // with the Trailer header, and then mutate the 749 // Header() of those values later, after the response 750 // has been written (we wrote to w above). 751 w.Header().Set("Server-Trailer-A", "valuea") 752 w.Header().Set("Server-Trailer-C", "valuec") // skipping B 753 w.Header().Set("Server-Trailer-NotDeclared", "should be omitted") 754 })) 755 756 res, err := cst.c.Get(cst.ts.URL) 757 if err != nil { 758 t.Fatal(err) 759 } 760 761 wantHeader := Header{ 762 "Content-Type": {"text/plain; charset=utf-8"}, 763 } 764 wantLen := -1 765 if mode == http2Mode && !flush { 766 // In HTTP/1.1, any use of trailers forces HTTP/1.1 767 // chunking and a flush at the first write. That's 768 // unnecessary with HTTP/2's framing, so the server 769 // is able to calculate the length while still sending 770 // trailers afterwards. 771 wantLen = len(body) 772 wantHeader["Content-Length"] = []string{fmt.Sprint(wantLen)} 773 } 774 if res.ContentLength != int64(wantLen) { 775 t.Errorf("ContentLength = %v; want %v", res.ContentLength, wantLen) 776 } 777 778 delete(res.Header, "Date") // irrelevant for test 779 if !reflect.DeepEqual(res.Header, wantHeader) { 780 t.Errorf("Header = %v; want %v", res.Header, wantHeader) 781 } 782 783 if got, want := res.Trailer, (Header{ 784 "Server-Trailer-A": nil, 785 "Server-Trailer-B": nil, 786 "Server-Trailer-C": nil, 787 }); !reflect.DeepEqual(got, want) { 788 t.Errorf("Trailer before body read = %v; want %v", got, want) 789 } 790 791 if err := wantBody(res, nil, body); err != nil { 792 t.Fatal(err) 793 } 794 795 if got, want := res.Trailer, (Header{ 796 "Server-Trailer-A": {"valuea"}, 797 "Server-Trailer-B": nil, 798 "Server-Trailer-C": {"valuec"}, 799 }); !reflect.DeepEqual(got, want) { 800 t.Errorf("Trailer after body read = %v; want %v", got, want) 801 } 802 } 803 804 // Don't allow a Body.Read after Body.Close. Issue 13648. 805 func TestResponseBodyReadAfterClose(t *testing.T) { run(t, testResponseBodyReadAfterClose) } 806 func testResponseBodyReadAfterClose(t *testing.T, mode testMode) { 807 const body = "Some body" 808 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { 809 io.WriteString(w, body) 810 })) 811 res, err := cst.c.Get(cst.ts.URL) 812 if err != nil { 813 t.Fatal(err) 814 } 815 res.Body.Close() 816 data, err := io.ReadAll(res.Body) 817 if len(data) != 0 || err == nil { 818 t.Fatalf("ReadAll returned %q, %v; want error", data, err) 819 } 820 } 821 822 func TestConcurrentReadWriteReqBody(t *testing.T) { run(t, testConcurrentReadWriteReqBody) } 823 func testConcurrentReadWriteReqBody(t *testing.T, mode testMode) { 824 const reqBody = "some request body" 825 const resBody = "some response body" 826 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { 827 var wg sync.WaitGroup 828 wg.Add(2) 829 didRead := make(chan bool, 1) 830 // Read in one goroutine. 831 go func() { 832 defer wg.Done() 833 data, err := io.ReadAll(r.Body) 834 if string(data) != reqBody { 835 t.Errorf("Handler read %q; want %q", data, reqBody) 836 } 837 if err != nil { 838 t.Errorf("Handler Read: %v", err) 839 } 840 didRead <- true 841 }() 842 // Write in another goroutine. 843 go func() { 844 defer wg.Done() 845 if mode != http2Mode { 846 // our HTTP/1 implementation intentionally 847 // doesn't permit writes during read (mostly 848 // due to it being undefined); if that is ever 849 // relaxed, change this. 850 <-didRead 851 } 852 io.WriteString(w, resBody) 853 }() 854 wg.Wait() 855 })) 856 req, _ := NewRequest("POST", cst.ts.URL, strings.NewReader(reqBody)) 857 req.Header.Add("Expect", "100-continue") // just to complicate things 858 res, err := cst.c.Do(req) 859 if err != nil { 860 t.Fatal(err) 861 } 862 data, err := io.ReadAll(res.Body) 863 defer res.Body.Close() 864 if err != nil { 865 t.Fatal(err) 866 } 867 if string(data) != resBody { 868 t.Errorf("read %q; want %q", data, resBody) 869 } 870 } 871 872 func TestConnectRequest(t *testing.T) { run(t, testConnectRequest) } 873 func testConnectRequest(t *testing.T, mode testMode) { 874 gotc := make(chan *Request, 1) 875 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { 876 gotc <- r 877 })) 878 879 u, err := url.Parse(cst.ts.URL) 880 if err != nil { 881 t.Fatal(err) 882 } 883 884 tests := []struct { 885 req *Request 886 want string 887 }{ 888 { 889 req: &Request{ 890 Method: "CONNECT", 891 Header: Header{}, 892 URL: u, 893 }, 894 want: u.Host, 895 }, 896 { 897 req: &Request{ 898 Method: "CONNECT", 899 Header: Header{}, 900 URL: u, 901 Host: "example.com:123", 902 }, 903 want: "example.com:123", 904 }, 905 } 906 907 for i, tt := range tests { 908 res, err := cst.c.Do(tt.req) 909 if err != nil { 910 t.Errorf("%d. RoundTrip = %v", i, err) 911 continue 912 } 913 res.Body.Close() 914 req := <-gotc 915 if req.Method != "CONNECT" { 916 t.Errorf("method = %q; want CONNECT", req.Method) 917 } 918 if req.Host != tt.want { 919 t.Errorf("Host = %q; want %q", req.Host, tt.want) 920 } 921 if req.URL.Host != tt.want { 922 t.Errorf("URL.Host = %q; want %q", req.URL.Host, tt.want) 923 } 924 } 925 } 926 927 func TestTransportUserAgent(t *testing.T) { run(t, testTransportUserAgent) } 928 func testTransportUserAgent(t *testing.T, mode testMode) { 929 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { 930 fmt.Fprintf(w, "%q", r.Header["User-Agent"]) 931 })) 932 933 either := func(a, b string) string { 934 if mode == http2Mode { 935 return b 936 } 937 return a 938 } 939 940 tests := []struct { 941 setup func(*Request) 942 want string 943 }{ 944 { 945 func(r *Request) {}, 946 either(`["Go-http-client/1.1"]`, `["Go-http-client/2.0"]`), 947 }, 948 { 949 func(r *Request) { r.Header.Set("User-Agent", "foo/1.2.3") }, 950 `["foo/1.2.3"]`, 951 }, 952 { 953 func(r *Request) { r.Header["User-Agent"] = []string{"single", "or", "multiple"} }, 954 `["single"]`, 955 }, 956 { 957 func(r *Request) { r.Header.Set("User-Agent", "") }, 958 `[]`, 959 }, 960 { 961 func(r *Request) { r.Header["User-Agent"] = nil }, 962 `[]`, 963 }, 964 } 965 for i, tt := range tests { 966 req, _ := NewRequest("GET", cst.ts.URL, nil) 967 tt.setup(req) 968 res, err := cst.c.Do(req) 969 if err != nil { 970 t.Errorf("%d. RoundTrip = %v", i, err) 971 continue 972 } 973 slurp, err := io.ReadAll(res.Body) 974 res.Body.Close() 975 if err != nil { 976 t.Errorf("%d. read body = %v", i, err) 977 continue 978 } 979 if string(slurp) != tt.want { 980 t.Errorf("%d. body mismatch.\n got: %s\nwant: %s\n", i, slurp, tt.want) 981 } 982 } 983 } 984 985 func TestStarRequestMethod(t *testing.T) { 986 for _, method := range []string{"FOO", "OPTIONS"} { 987 t.Run(method, func(t *testing.T) { 988 run(t, func(t *testing.T, mode testMode) { 989 testStarRequest(t, method, mode) 990 }) 991 }) 992 } 993 } 994 func testStarRequest(t *testing.T, method string, mode testMode) { 995 gotc := make(chan *Request, 1) 996 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { 997 w.Header().Set("foo", "bar") 998 gotc <- r 999 w.(Flusher).Flush() 1000 })) 1001 1002 u, err := url.Parse(cst.ts.URL) 1003 if err != nil { 1004 t.Fatal(err) 1005 } 1006 u.Path = "*" 1007 1008 req := &Request{ 1009 Method: method, 1010 Header: Header{}, 1011 URL: u, 1012 } 1013 1014 res, err := cst.c.Do(req) 1015 if err != nil { 1016 t.Fatalf("RoundTrip = %v", err) 1017 } 1018 res.Body.Close() 1019 1020 wantFoo := "bar" 1021 wantLen := int64(-1) 1022 if method == "OPTIONS" { 1023 wantFoo = "" 1024 wantLen = 0 1025 } 1026 if res.StatusCode != 200 { 1027 t.Errorf("status code = %v; want %d", res.Status, 200) 1028 } 1029 if res.ContentLength != wantLen { 1030 t.Errorf("content length = %v; want %d", res.ContentLength, wantLen) 1031 } 1032 if got := res.Header.Get("foo"); got != wantFoo { 1033 t.Errorf("response \"foo\" header = %q; want %q", got, wantFoo) 1034 } 1035 select { 1036 case req = <-gotc: 1037 default: 1038 req = nil 1039 } 1040 if req == nil { 1041 if method != "OPTIONS" { 1042 t.Fatalf("handler never got request") 1043 } 1044 return 1045 } 1046 if req.Method != method { 1047 t.Errorf("method = %q; want %q", req.Method, method) 1048 } 1049 if req.URL.Path != "*" { 1050 t.Errorf("URL.Path = %q; want *", req.URL.Path) 1051 } 1052 if req.RequestURI != "*" { 1053 t.Errorf("RequestURI = %q; want *", req.RequestURI) 1054 } 1055 } 1056 1057 // Issue 13957 1058 func TestTransportDiscardsUnneededConns(t *testing.T) { 1059 run(t, testTransportDiscardsUnneededConns, []testMode{http2Mode}) 1060 } 1061 func testTransportDiscardsUnneededConns(t *testing.T, mode testMode) { 1062 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { 1063 fmt.Fprintf(w, "Hello, %v", r.RemoteAddr) 1064 })) 1065 defer cst.close() 1066 1067 var numOpen, numClose int32 // atomic 1068 1069 tlsConfig := &tls.Config{InsecureSkipVerify: true} 1070 tr := &Transport{ 1071 TLSClientConfig: tlsConfig, 1072 DialTLS: func(_, addr string) (net.Conn, error) { 1073 time.Sleep(10 * time.Millisecond) 1074 rc, err := net.Dial("tcp", addr) 1075 if err != nil { 1076 return nil, err 1077 } 1078 atomic.AddInt32(&numOpen, 1) 1079 c := noteCloseConn{rc, func() { atomic.AddInt32(&numClose, 1) }} 1080 return TLSClientFactory(c, tlsConfig), nil 1081 }, 1082 } 1083 if err := ExportHttp2ConfigureTransport(tr); err != nil { 1084 t.Fatal(err) 1085 } 1086 defer tr.CloseIdleConnections() 1087 1088 c := &Client{Transport: tr} 1089 1090 const N = 10 1091 gotBody := make(chan string, N) 1092 var wg sync.WaitGroup 1093 for i := 0; i < N; i++ { 1094 wg.Add(1) 1095 go func() { 1096 defer wg.Done() 1097 resp, err := c.Get(cst.ts.URL) 1098 if err != nil { 1099 // Try to work around spurious connection reset on loaded system. 1100 // See golang.org/issue/33585 and golang.org/issue/36797. 1101 time.Sleep(10 * time.Millisecond) 1102 resp, err = c.Get(cst.ts.URL) 1103 if err != nil { 1104 t.Errorf("Get: %v", err) 1105 return 1106 } 1107 } 1108 defer resp.Body.Close() 1109 slurp, err := io.ReadAll(resp.Body) 1110 if err != nil { 1111 t.Error(err) 1112 } 1113 gotBody <- string(slurp) 1114 }() 1115 } 1116 wg.Wait() 1117 close(gotBody) 1118 1119 var last string 1120 for got := range gotBody { 1121 if last == "" { 1122 last = got 1123 continue 1124 } 1125 if got != last { 1126 t.Errorf("Response body changed: %q -> %q", last, got) 1127 } 1128 } 1129 1130 var open, close int32 1131 for i := 0; i < 150; i++ { 1132 open, close = atomic.LoadInt32(&numOpen), atomic.LoadInt32(&numClose) 1133 if open < 1 { 1134 t.Fatalf("open = %d; want at least", open) 1135 } 1136 if close == open-1 { 1137 // Success 1138 return 1139 } 1140 time.Sleep(10 * time.Millisecond) 1141 } 1142 t.Errorf("%d connections opened, %d closed; want %d to close", open, close, open-1) 1143 } 1144 1145 // tests that Transport doesn't retain a pointer to the provided request. 1146 func TestTransportGCRequest(t *testing.T) { 1147 run(t, func(t *testing.T, mode testMode) { 1148 t.Run("Body", func(t *testing.T) { testTransportGCRequest(t, mode, true) }) 1149 t.Run("NoBody", func(t *testing.T) { testTransportGCRequest(t, mode, false) }) 1150 }) 1151 } 1152 func testTransportGCRequest(t *testing.T, mode testMode, body bool) { 1153 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { 1154 io.ReadAll(r.Body) 1155 if body { 1156 io.WriteString(w, "Hello.") 1157 } 1158 })) 1159 1160 didGC := make(chan struct{}) 1161 (func() { 1162 body := strings.NewReader("some body") 1163 req, _ := NewRequest("POST", cst.ts.URL, body) 1164 runtime.SetFinalizer(req, func(*Request) { close(didGC) }) 1165 res, err := cst.c.Do(req) 1166 if err != nil { 1167 t.Fatal(err) 1168 } 1169 if _, err := io.ReadAll(res.Body); err != nil { 1170 t.Fatal(err) 1171 } 1172 if err := res.Body.Close(); err != nil { 1173 t.Fatal(err) 1174 } 1175 })() 1176 timeout := time.NewTimer(5 * time.Second) 1177 defer timeout.Stop() 1178 for { 1179 select { 1180 case <-didGC: 1181 return 1182 case <-time.After(100 * time.Millisecond): 1183 runtime.GC() 1184 case <-timeout.C: 1185 t.Fatal("never saw GC of request") 1186 } 1187 } 1188 } 1189 1190 func TestTransportRejectsInvalidHeaders(t *testing.T) { run(t, testTransportRejectsInvalidHeaders) } 1191 func testTransportRejectsInvalidHeaders(t *testing.T, mode testMode) { 1192 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { 1193 fmt.Fprintf(w, "Handler saw headers: %q", r.Header) 1194 }), optQuietLog) 1195 cst.tr.DisableKeepAlives = true 1196 1197 tests := []struct { 1198 key, val string 1199 ok bool 1200 }{ 1201 {"Foo", "capital-key", true}, // verify h2 allows capital keys 1202 {"Foo", "foo\x00bar", false}, // \x00 byte in value not allowed 1203 {"Foo", "two\nlines", false}, // \n byte in value not allowed 1204 {"bogus\nkey", "v", false}, // \n byte also not allowed in key 1205 {"A space", "v", false}, // spaces in keys not allowed 1206 {"имя", "v", false}, // key must be ascii 1207 {"name", "валю", true}, // value may be non-ascii 1208 {"", "v", false}, // key must be non-empty 1209 {"k", "", true}, // value may be empty 1210 } 1211 for _, tt := range tests { 1212 dialedc := make(chan bool, 1) 1213 cst.tr.Dial = func(netw, addr string) (net.Conn, error) { 1214 dialedc <- true 1215 return net.Dial(netw, addr) 1216 } 1217 req, _ := NewRequest("GET", cst.ts.URL, nil) 1218 req.Header[tt.key] = []string{tt.val} 1219 res, err := cst.c.Do(req) 1220 var body []byte 1221 if err == nil { 1222 body, _ = io.ReadAll(res.Body) 1223 res.Body.Close() 1224 } 1225 var dialed bool 1226 select { 1227 case <-dialedc: 1228 dialed = true 1229 default: 1230 } 1231 1232 if !tt.ok && dialed { 1233 t.Errorf("For key %q, value %q, transport dialed. Expected local failure. Response was: (%v, %v)\nServer replied with: %s", tt.key, tt.val, res, err, body) 1234 } else if (err == nil) != tt.ok { 1235 t.Errorf("For key %q, value %q; got err = %v; want ok=%v", tt.key, tt.val, err, tt.ok) 1236 } 1237 } 1238 } 1239 1240 func TestInterruptWithPanic(t *testing.T) { 1241 run(t, func(t *testing.T, mode testMode) { 1242 t.Run("boom", func(t *testing.T) { testInterruptWithPanic(t, mode, "boom") }) 1243 t.Run("nil", func(t *testing.T) { t.Setenv("GODEBUG", "panicnil=1"); testInterruptWithPanic(t, mode, nil) }) 1244 t.Run("ErrAbortHandler", func(t *testing.T) { testInterruptWithPanic(t, mode, ErrAbortHandler) }) 1245 }, testNotParallel) 1246 } 1247 func testInterruptWithPanic(t *testing.T, mode testMode, panicValue any) { 1248 const msg = "hello" 1249 1250 testDone := make(chan struct{}) 1251 defer close(testDone) 1252 1253 var errorLog lockedBytesBuffer 1254 gotHeaders := make(chan bool, 1) 1255 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { 1256 io.WriteString(w, msg) 1257 w.(Flusher).Flush() 1258 1259 select { 1260 case <-gotHeaders: 1261 case <-testDone: 1262 } 1263 panic(panicValue) 1264 }), func(ts *httptest.Server) { 1265 ts.Config.ErrorLog = log.New(&errorLog, "", 0) 1266 }) 1267 res, err := cst.c.Get(cst.ts.URL) 1268 if err != nil { 1269 t.Fatal(err) 1270 } 1271 gotHeaders <- true 1272 defer res.Body.Close() 1273 slurp, err := io.ReadAll(res.Body) 1274 if string(slurp) != msg { 1275 t.Errorf("client read %q; want %q", slurp, msg) 1276 } 1277 if err == nil { 1278 t.Errorf("client read all successfully; want some error") 1279 } 1280 logOutput := func() string { 1281 errorLog.Lock() 1282 defer errorLog.Unlock() 1283 return errorLog.String() 1284 } 1285 wantStackLogged := panicValue != nil && panicValue != ErrAbortHandler 1286 1287 waitCondition(t, 10*time.Millisecond, func(d time.Duration) bool { 1288 gotLog := logOutput() 1289 if !wantStackLogged { 1290 if gotLog == "" { 1291 return true 1292 } 1293 t.Fatalf("want no log output; got: %s", gotLog) 1294 } 1295 if gotLog == "" { 1296 if d > 0 { 1297 t.Logf("wanted a stack trace logged; got nothing after %v", d) 1298 } 1299 return false 1300 } 1301 if !strings.Contains(gotLog, "created by ") && strings.Count(gotLog, "\n") < 6 { 1302 if d > 0 { 1303 t.Logf("output doesn't look like a panic stack trace after %v. Got: %s", d, gotLog) 1304 } 1305 return false 1306 } 1307 return true 1308 }) 1309 } 1310 1311 type lockedBytesBuffer struct { 1312 sync.Mutex 1313 bytes.Buffer 1314 } 1315 1316 func (b *lockedBytesBuffer) Write(p []byte) (int, error) { 1317 b.Lock() 1318 defer b.Unlock() 1319 return b.Buffer.Write(p) 1320 } 1321 1322 // Issue 15366 1323 func TestH12_AutoGzipWithDumpResponse(t *testing.T) { 1324 h12Compare{ 1325 Handler: func(w ResponseWriter, r *Request) { 1326 h := w.Header() 1327 h.Set("Content-Encoding", "gzip") 1328 h.Set("Content-Length", "23") 1329 io.WriteString(w, "\x1f\x8b\b\x00\x00\x00\x00\x00\x00\x00s\xf3\xf7\a\x00\xab'\xd4\x1a\x03\x00\x00\x00") 1330 }, 1331 EarlyCheckResponse: func(proto string, res *Response) { 1332 if !res.Uncompressed { 1333 t.Errorf("%s: expected Uncompressed to be set", proto) 1334 } 1335 dump, err := httputil.DumpResponse(res, true) 1336 if err != nil { 1337 t.Errorf("%s: DumpResponse: %v", proto, err) 1338 return 1339 } 1340 if strings.Contains(string(dump), "Connection: close") { 1341 t.Errorf("%s: should not see \"Connection: close\" in dump; got:\n%s", proto, dump) 1342 } 1343 if !strings.Contains(string(dump), "FOO") { 1344 t.Errorf("%s: should see \"FOO\" in response; got:\n%s", proto, dump) 1345 } 1346 }, 1347 }.run(t) 1348 } 1349 1350 // Issue 14607 1351 func TestCloseIdleConnections(t *testing.T) { run(t, testCloseIdleConnections) } 1352 func testCloseIdleConnections(t *testing.T, mode testMode) { 1353 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { 1354 w.Header().Set("X-Addr", r.RemoteAddr) 1355 })) 1356 get := func() string { 1357 res, err := cst.c.Get(cst.ts.URL) 1358 if err != nil { 1359 t.Fatal(err) 1360 } 1361 res.Body.Close() 1362 v := res.Header.Get("X-Addr") 1363 if v == "" { 1364 t.Fatal("didn't get X-Addr") 1365 } 1366 return v 1367 } 1368 a1 := get() 1369 cst.tr.CloseIdleConnections() 1370 a2 := get() 1371 if a1 == a2 { 1372 t.Errorf("didn't close connection") 1373 } 1374 } 1375 1376 type noteCloseConn struct { 1377 net.Conn 1378 closeFunc func() 1379 } 1380 1381 func (x noteCloseConn) Close() error { 1382 x.closeFunc() 1383 return x.Conn.Close() 1384 } 1385 1386 type testErrorReader struct{ t *testing.T } 1387 1388 func (r testErrorReader) Read(p []byte) (n int, err error) { 1389 r.t.Error("unexpected Read call") 1390 return 0, io.EOF 1391 } 1392 1393 func TestNoSniffExpectRequestBody(t *testing.T) { run(t, testNoSniffExpectRequestBody) } 1394 func testNoSniffExpectRequestBody(t *testing.T, mode testMode) { 1395 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { 1396 w.WriteHeader(StatusUnauthorized) 1397 })) 1398 1399 // Set ExpectContinueTimeout non-zero so RoundTrip won't try to write it. 1400 cst.tr.ExpectContinueTimeout = 10 * time.Second 1401 1402 req, err := NewRequest("POST", cst.ts.URL, testErrorReader{t}) 1403 if err != nil { 1404 t.Fatal(err) 1405 } 1406 req.ContentLength = 0 // so transport is tempted to sniff it 1407 req.Header.Set("Expect", "100-continue") 1408 res, err := cst.tr.RoundTrip(req) 1409 if err != nil { 1410 t.Fatal(err) 1411 } 1412 defer res.Body.Close() 1413 if res.StatusCode != StatusUnauthorized { 1414 t.Errorf("status code = %v; want %v", res.StatusCode, StatusUnauthorized) 1415 } 1416 } 1417 1418 func TestServerUndeclaredTrailers(t *testing.T) { run(t, testServerUndeclaredTrailers) } 1419 func testServerUndeclaredTrailers(t *testing.T, mode testMode) { 1420 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { 1421 w.Header().Set("Foo", "Bar") 1422 w.Header().Set("Trailer:Foo", "Baz") 1423 w.(Flusher).Flush() 1424 w.Header().Add("Trailer:Foo", "Baz2") 1425 w.Header().Set("Trailer:Bar", "Quux") 1426 })) 1427 res, err := cst.c.Get(cst.ts.URL) 1428 if err != nil { 1429 t.Fatal(err) 1430 } 1431 if _, err := io.Copy(io.Discard, res.Body); err != nil { 1432 t.Fatal(err) 1433 } 1434 res.Body.Close() 1435 delete(res.Header, "Date") 1436 delete(res.Header, "Content-Type") 1437 1438 if want := (Header{"Foo": {"Bar"}}); !reflect.DeepEqual(res.Header, want) { 1439 t.Errorf("Header = %#v; want %#v", res.Header, want) 1440 } 1441 if want := (Header{"Foo": {"Baz", "Baz2"}, "Bar": {"Quux"}}); !reflect.DeepEqual(res.Trailer, want) { 1442 t.Errorf("Trailer = %#v; want %#v", res.Trailer, want) 1443 } 1444 } 1445 1446 func TestBadResponseAfterReadingBody(t *testing.T) { 1447 run(t, testBadResponseAfterReadingBody, []testMode{http1Mode}) 1448 } 1449 func testBadResponseAfterReadingBody(t *testing.T, mode testMode) { 1450 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { 1451 _, err := io.Copy(io.Discard, r.Body) 1452 if err != nil { 1453 t.Fatal(err) 1454 } 1455 c, _, err := w.(Hijacker).Hijack() 1456 if err != nil { 1457 t.Fatal(err) 1458 } 1459 defer c.Close() 1460 fmt.Fprintln(c, "some bogus crap") 1461 })) 1462 1463 closes := 0 1464 res, err := cst.c.Post(cst.ts.URL, "text/plain", countCloseReader{&closes, strings.NewReader("hello")}) 1465 if err == nil { 1466 res.Body.Close() 1467 t.Fatal("expected an error to be returned from Post") 1468 } 1469 if closes != 1 { 1470 t.Errorf("closes = %d; want 1", closes) 1471 } 1472 } 1473 1474 func TestWriteHeader0(t *testing.T) { run(t, testWriteHeader0) } 1475 func testWriteHeader0(t *testing.T, mode testMode) { 1476 gotpanic := make(chan bool, 1) 1477 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { 1478 defer close(gotpanic) 1479 defer func() { 1480 if e := recover(); e != nil { 1481 got := fmt.Sprintf("%T, %v", e, e) 1482 want := "string, invalid WriteHeader code 0" 1483 if got != want { 1484 t.Errorf("unexpected panic value:\n got: %v\nwant: %v\n", got, want) 1485 } 1486 gotpanic <- true 1487 1488 // Set an explicit 503. This also tests that the WriteHeader call panics 1489 // before it recorded that an explicit value was set and that bogus 1490 // value wasn't stuck. 1491 w.WriteHeader(503) 1492 } 1493 }() 1494 w.WriteHeader(0) 1495 })) 1496 res, err := cst.c.Get(cst.ts.URL) 1497 if err != nil { 1498 t.Fatal(err) 1499 } 1500 if res.StatusCode != 503 { 1501 t.Errorf("Response: %v %q; want 503", res.StatusCode, res.Status) 1502 } 1503 if !<-gotpanic { 1504 t.Error("expected panic in handler") 1505 } 1506 } 1507 1508 // Issue 23010: don't be super strict checking WriteHeader's code if 1509 // it's not even valid to call WriteHeader then anyway. 1510 func TestWriteHeaderNoCodeCheck(t *testing.T) { 1511 run(t, func(t *testing.T, mode testMode) { 1512 testWriteHeaderAfterWrite(t, mode, false) 1513 }) 1514 } 1515 func TestWriteHeaderNoCodeCheck_h1hijack(t *testing.T) { 1516 testWriteHeaderAfterWrite(t, http1Mode, true) 1517 } 1518 func testWriteHeaderAfterWrite(t *testing.T, mode testMode, hijack bool) { 1519 var errorLog lockedBytesBuffer 1520 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { 1521 if hijack { 1522 conn, _, _ := w.(Hijacker).Hijack() 1523 defer conn.Close() 1524 conn.Write([]byte("HTTP/1.1 200 OK\r\nContent-Length: 6\r\n\r\nfoo")) 1525 w.WriteHeader(0) // verify this doesn't panic if there's already output; Issue 23010 1526 conn.Write([]byte("bar")) 1527 return 1528 } 1529 io.WriteString(w, "foo") 1530 w.(Flusher).Flush() 1531 w.WriteHeader(0) // verify this doesn't panic if there's already output; Issue 23010 1532 io.WriteString(w, "bar") 1533 }), func(ts *httptest.Server) { 1534 ts.Config.ErrorLog = log.New(&errorLog, "", 0) 1535 }) 1536 res, err := cst.c.Get(cst.ts.URL) 1537 if err != nil { 1538 t.Fatal(err) 1539 } 1540 defer res.Body.Close() 1541 body, err := io.ReadAll(res.Body) 1542 if err != nil { 1543 t.Fatal(err) 1544 } 1545 if got, want := string(body), "foobar"; got != want { 1546 t.Errorf("got = %q; want %q", got, want) 1547 } 1548 1549 // Also check the stderr output: 1550 if mode == http2Mode { 1551 // TODO: also emit this log message for HTTP/2? 1552 // We historically haven't, so don't check. 1553 return 1554 } 1555 gotLog := strings.TrimSpace(errorLog.String()) 1556 wantLog := "http: superfluous response.WriteHeader call from github.com/ooni/oohttp.relevantCaller (server.go:" 1557 if hijack { 1558 wantLog = "http: response.WriteHeader on hijacked connection from github.com/ooni/oohttp.relevantCaller (server.go:" 1559 } 1560 if !strings.HasPrefix(gotLog, wantLog) { 1561 t.Errorf("stderr output = %q; want %q", gotLog, wantLog) 1562 } 1563 } 1564 1565 func TestBidiStreamReverseProxy(t *testing.T) { 1566 run(t, testBidiStreamReverseProxy, []testMode{http2Mode}) 1567 } 1568 func testBidiStreamReverseProxy(t *testing.T, mode testMode) { 1569 backend := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { 1570 if _, err := io.Copy(w, r.Body); err != nil { 1571 log.Printf("bidi backend copy: %v", err) 1572 } 1573 })) 1574 1575 backURL, err := url.Parse(backend.ts.URL) 1576 if err != nil { 1577 t.Fatal(err) 1578 } 1579 rp := httputil.NewSingleHostReverseProxy(backURL) 1580 rp.Transport = backend.tr 1581 proxy := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { 1582 rp.ServeHTTP(w, r) 1583 })) 1584 1585 bodyRes := make(chan any, 1) // error or hash.Hash 1586 pr, pw := io.Pipe() 1587 req, _ := NewRequest("PUT", proxy.ts.URL, pr) 1588 const size = 4 << 20 1589 go func() { 1590 h := sha1.New() 1591 _, err := io.CopyN(io.MultiWriter(h, pw), rand.Reader, size) 1592 go pw.Close() 1593 if err != nil { 1594 bodyRes <- err 1595 } else { 1596 bodyRes <- h 1597 } 1598 }() 1599 res, err := backend.c.Do(req) 1600 if err != nil { 1601 t.Fatal(err) 1602 } 1603 defer res.Body.Close() 1604 hgot := sha1.New() 1605 n, err := io.Copy(hgot, res.Body) 1606 if err != nil { 1607 t.Fatal(err) 1608 } 1609 if n != size { 1610 t.Fatalf("got %d bytes; want %d", n, size) 1611 } 1612 select { 1613 case v := <-bodyRes: 1614 switch v := v.(type) { 1615 default: 1616 t.Fatalf("body copy: %v", err) 1617 case hash.Hash: 1618 if !bytes.Equal(v.Sum(nil), hgot.Sum(nil)) { 1619 t.Errorf("written bytes didn't match received bytes") 1620 } 1621 } 1622 case <-time.After(10 * time.Second): 1623 t.Fatal("timeout") 1624 } 1625 1626 } 1627 1628 // Always use HTTP/1.1 for WebSocket upgrades. 1629 func TestH12_WebSocketUpgrade(t *testing.T) { 1630 h12Compare{ 1631 Handler: func(w ResponseWriter, r *Request) { 1632 h := w.Header() 1633 h.Set("Foo", "bar") 1634 }, 1635 ReqFunc: func(c *Client, url string) (*Response, error) { 1636 req, _ := NewRequest("GET", url, nil) 1637 req.Header.Set("Connection", "Upgrade") 1638 req.Header.Set("Upgrade", "WebSocket") 1639 return c.Do(req) 1640 }, 1641 EarlyCheckResponse: func(proto string, res *Response) { 1642 if res.Proto != "HTTP/1.1" { 1643 t.Errorf("%s: expected HTTP/1.1, got %q", proto, res.Proto) 1644 } 1645 res.Proto = "HTTP/IGNORE" // skip later checks that Proto must be 1.1 vs 2.0 1646 }, 1647 }.run(t) 1648 } 1649 1650 func TestIdentityTransferEncoding(t *testing.T) { run(t, testIdentityTransferEncoding) } 1651 func testIdentityTransferEncoding(t *testing.T, mode testMode) { 1652 const body = "body" 1653 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { 1654 gotBody, _ := io.ReadAll(r.Body) 1655 if got, want := string(gotBody), body; got != want { 1656 t.Errorf("got request body = %q; want %q", got, want) 1657 } 1658 w.Header().Set("Transfer-Encoding", "identity") 1659 w.WriteHeader(StatusOK) 1660 w.(Flusher).Flush() 1661 io.WriteString(w, body) 1662 })) 1663 req, _ := NewRequest("GET", cst.ts.URL, strings.NewReader(body)) 1664 res, err := cst.c.Do(req) 1665 if err != nil { 1666 t.Fatal(err) 1667 } 1668 defer res.Body.Close() 1669 gotBody, err := io.ReadAll(res.Body) 1670 if err != nil { 1671 t.Fatal(err) 1672 } 1673 if got, want := string(gotBody), body; got != want { 1674 t.Errorf("got response body = %q; want %q", got, want) 1675 } 1676 } 1677 1678 func TestEarlyHintsRequest(t *testing.T) { run(t, testEarlyHintsRequest) } 1679 func testEarlyHintsRequest(t *testing.T, mode testMode) { 1680 var wg sync.WaitGroup 1681 wg.Add(1) 1682 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { 1683 h := w.Header() 1684 1685 h.Add("Content-Length", "123") // must be ignored 1686 h.Add("Link", "</style.css>; rel=preload; as=style") 1687 h.Add("Link", "</script.js>; rel=preload; as=script") 1688 w.WriteHeader(StatusEarlyHints) 1689 1690 wg.Wait() 1691 1692 h.Add("Link", "</foo.js>; rel=preload; as=script") 1693 w.WriteHeader(StatusEarlyHints) 1694 1695 w.Write([]byte("Hello")) 1696 })) 1697 1698 checkLinkHeaders := func(t *testing.T, expected, got []string) { 1699 t.Helper() 1700 1701 if len(expected) != len(got) { 1702 t.Errorf("got %d expected %d", len(got), len(expected)) 1703 } 1704 1705 for i := range expected { 1706 if expected[i] != got[i] { 1707 t.Errorf("got %q expected %q", got[i], expected[i]) 1708 } 1709 } 1710 } 1711 1712 checkExcludedHeaders := func(t *testing.T, header textproto.MIMEHeader) { 1713 t.Helper() 1714 1715 for _, h := range []string{"Content-Length", "Transfer-Encoding"} { 1716 if v, ok := header[h]; ok { 1717 t.Errorf("%s is %q; must not be sent", h, v) 1718 } 1719 } 1720 } 1721 1722 var respCounter uint8 1723 trace := &httptrace.ClientTrace{ 1724 Got1xxResponse: func(code int, header textproto.MIMEHeader) error { 1725 switch respCounter { 1726 case 0: 1727 checkLinkHeaders(t, []string{"</style.css>; rel=preload; as=style", "</script.js>; rel=preload; as=script"}, header["Link"]) 1728 checkExcludedHeaders(t, header) 1729 1730 wg.Done() 1731 case 1: 1732 checkLinkHeaders(t, []string{"</style.css>; rel=preload; as=style", "</script.js>; rel=preload; as=script", "</foo.js>; rel=preload; as=script"}, header["Link"]) 1733 checkExcludedHeaders(t, header) 1734 1735 default: 1736 t.Error("Unexpected 1xx response") 1737 } 1738 1739 respCounter++ 1740 1741 return nil 1742 }, 1743 } 1744 req, _ := NewRequestWithContext(httptrace.WithClientTrace(context.Background(), trace), "GET", cst.ts.URL, nil) 1745 1746 res, err := cst.c.Do(req) 1747 if err != nil { 1748 t.Fatal(err) 1749 } 1750 defer res.Body.Close() 1751 1752 checkLinkHeaders(t, []string{"</style.css>; rel=preload; as=style", "</script.js>; rel=preload; as=script", "</foo.js>; rel=preload; as=script"}, res.Header["Link"]) 1753 if cl := res.Header.Get("Content-Length"); cl != "123" { 1754 t.Errorf("Content-Length is %q; want 123", cl) 1755 } 1756 1757 body, _ := io.ReadAll(res.Body) 1758 if string(body) != "Hello" { 1759 t.Errorf("Read body %q; want Hello", body) 1760 } 1761 }