github.com/mtsmfm/go/src@v0.0.0-20221020090648-44bdcb9f8fde/net/http/clientserver_test.go (about)

     1  // Copyright 2015 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  // Tests that use both the client & server, in both HTTP/1 and HTTP/2 mode.
     6  
     7  package http_test
     8  
     9  import (
    10  	"bytes"
    11  	"compress/gzip"
    12  	"context"
    13  	"crypto/rand"
    14  	"crypto/sha1"
    15  	"crypto/tls"
    16  	"fmt"
    17  	"hash"
    18  	"io"
    19  	"log"
    20  	"net"
    21  	. "net/http"
    22  	"net/http/httptest"
    23  	"net/http/httptrace"
    24  	"net/http/httputil"
    25  	"net/textproto"
    26  	"net/url"
    27  	"os"
    28  	"reflect"
    29  	"runtime"
    30  	"sort"
    31  	"strings"
    32  	"sync"
    33  	"sync/atomic"
    34  	"testing"
    35  	"time"
    36  )
    37  
    38  type testMode string
    39  
    40  const (
    41  	http1Mode  = testMode("h1")     // HTTP/1.1
    42  	https1Mode = testMode("https1") // HTTPS/1.1
    43  	http2Mode  = testMode("h2")     // HTTP/2
    44  )
    45  
    46  type testNotParallelOpt struct{}
    47  
    48  var (
    49  	testNotParallel = testNotParallelOpt{}
    50  )
    51  
    52  type TBRun[T any] interface {
    53  	testing.TB
    54  	Run(string, func(T)) bool
    55  }
    56  
    57  // run runs a client/server test in a variety of test configurations.
    58  //
    59  // Tests execute in HTTP/1.1 and HTTP/2 modes by default.
    60  // To run in a different set of configurations, pass a []testMode option.
    61  //
    62  // Tests call t.Parallel() by default.
    63  // To disable parallel execution, pass the testNotParallel option.
    64  func run[T TBRun[T]](t T, f func(t T, mode testMode), opts ...any) {
    65  	t.Helper()
    66  	modes := []testMode{http1Mode, http2Mode}
    67  	parallel := true
    68  	for _, opt := range opts {
    69  		switch opt := opt.(type) {
    70  		case []testMode:
    71  			modes = opt
    72  		case testNotParallelOpt:
    73  			parallel = false
    74  		default:
    75  			t.Fatalf("unknown option type %T", opt)
    76  		}
    77  	}
    78  	if t, ok := any(t).(*testing.T); ok && parallel {
    79  		setParallel(t)
    80  	}
    81  	for _, mode := range modes {
    82  		t.Run(string(mode), func(t T) {
    83  			t.Helper()
    84  			if t, ok := any(t).(*testing.T); ok && parallel {
    85  				setParallel(t)
    86  			}
    87  			t.Cleanup(func() {
    88  				afterTest(t)
    89  			})
    90  			f(t, mode)
    91  		})
    92  	}
    93  }
    94  
    95  type clientServerTest struct {
    96  	t  testing.TB
    97  	h2 bool
    98  	h  Handler
    99  	ts *httptest.Server
   100  	tr *Transport
   101  	c  *Client
   102  }
   103  
   104  func (t *clientServerTest) close() {
   105  	t.tr.CloseIdleConnections()
   106  	t.ts.Close()
   107  }
   108  
   109  func (t *clientServerTest) getURL(u string) string {
   110  	res, err := t.c.Get(u)
   111  	if err != nil {
   112  		t.t.Fatal(err)
   113  	}
   114  	defer res.Body.Close()
   115  	slurp, err := io.ReadAll(res.Body)
   116  	if err != nil {
   117  		t.t.Fatal(err)
   118  	}
   119  	return string(slurp)
   120  }
   121  
   122  func (t *clientServerTest) scheme() string {
   123  	if t.h2 {
   124  		return "https"
   125  	}
   126  	return "http"
   127  }
   128  
   129  var optQuietLog = func(ts *httptest.Server) {
   130  	ts.Config.ErrorLog = quietLog
   131  }
   132  
   133  func optWithServerLog(lg *log.Logger) func(*httptest.Server) {
   134  	return func(ts *httptest.Server) {
   135  		ts.Config.ErrorLog = lg
   136  	}
   137  }
   138  
   139  // newClientServerTest creates and starts an httptest.Server.
   140  //
   141  // The mode parameter selects the implementation to test:
   142  // HTTP/1, HTTP/2, etc. Tests using newClientServerTest should use
   143  // the 'run' function, which will start a subtests for each tested mode.
   144  //
   145  // The vararg opts parameter can include functions to configure the
   146  // test server or transport.
   147  //
   148  //	func(*httptest.Server) // run before starting the server
   149  //	func(*http.Transport)
   150  func newClientServerTest(t testing.TB, mode testMode, h Handler, opts ...any) *clientServerTest {
   151  	if mode == http2Mode {
   152  		CondSkipHTTP2(t)
   153  	}
   154  	cst := &clientServerTest{
   155  		t:  t,
   156  		h2: mode == http2Mode,
   157  		h:  h,
   158  	}
   159  	cst.ts = httptest.NewUnstartedServer(h)
   160  
   161  	var transportFuncs []func(*Transport)
   162  	for _, opt := range opts {
   163  		switch opt := opt.(type) {
   164  		case func(*Transport):
   165  			transportFuncs = append(transportFuncs, opt)
   166  		case func(*httptest.Server):
   167  			opt(cst.ts)
   168  		default:
   169  			t.Fatalf("unhandled option type %T", opt)
   170  		}
   171  	}
   172  
   173  	switch mode {
   174  	case http1Mode:
   175  		cst.ts.Start()
   176  	case https1Mode:
   177  		cst.ts.StartTLS()
   178  	case http2Mode:
   179  		ExportHttp2ConfigureServer(cst.ts.Config, nil)
   180  		cst.ts.TLS = cst.ts.Config.TLSConfig
   181  		cst.ts.StartTLS()
   182  	default:
   183  		t.Fatalf("unknown test mode %v", mode)
   184  	}
   185  	cst.c = cst.ts.Client()
   186  	cst.tr = cst.c.Transport.(*Transport)
   187  	if mode == http2Mode {
   188  		if err := ExportHttp2ConfigureTransport(cst.tr); err != nil {
   189  			t.Fatal(err)
   190  		}
   191  	}
   192  	for _, f := range transportFuncs {
   193  		f(cst.tr)
   194  	}
   195  	t.Cleanup(func() {
   196  		cst.close()
   197  	})
   198  	return cst
   199  }
   200  
   201  // Testing the newClientServerTest helper itself.
   202  func TestNewClientServerTest(t *testing.T) {
   203  	run(t, testNewClientServerTest, []testMode{http1Mode, https1Mode, http2Mode})
   204  }
   205  func testNewClientServerTest(t *testing.T, mode testMode) {
   206  	var got struct {
   207  		sync.Mutex
   208  		proto  string
   209  		hasTLS bool
   210  	}
   211  	h := HandlerFunc(func(w ResponseWriter, r *Request) {
   212  		got.Lock()
   213  		defer got.Unlock()
   214  		got.proto = r.Proto
   215  		got.hasTLS = r.TLS != nil
   216  	})
   217  	cst := newClientServerTest(t, mode, h)
   218  	if _, err := cst.c.Head(cst.ts.URL); err != nil {
   219  		t.Fatal(err)
   220  	}
   221  	var wantProto string
   222  	var wantTLS bool
   223  	switch mode {
   224  	case http1Mode:
   225  		wantProto = "HTTP/1.1"
   226  		wantTLS = false
   227  	case https1Mode:
   228  		wantProto = "HTTP/1.1"
   229  		wantTLS = true
   230  	case http2Mode:
   231  		wantProto = "HTTP/2.0"
   232  		wantTLS = true
   233  	}
   234  	if got.proto != wantProto {
   235  		t.Errorf("req.Proto = %q, want %q", got.proto, wantProto)
   236  	}
   237  	if got.hasTLS != wantTLS {
   238  		t.Errorf("req.TLS set: %v, want %v", got.hasTLS, wantTLS)
   239  	}
   240  }
   241  
   242  func TestChunkedResponseHeaders(t *testing.T) { run(t, testChunkedResponseHeaders) }
   243  func testChunkedResponseHeaders(t *testing.T, mode testMode) {
   244  	log.SetOutput(io.Discard) // is noisy otherwise
   245  	defer log.SetOutput(os.Stderr)
   246  	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
   247  		w.Header().Set("Content-Length", "intentional gibberish") // we check that this is deleted
   248  		w.(Flusher).Flush()
   249  		fmt.Fprintf(w, "I am a chunked response.")
   250  	}))
   251  
   252  	res, err := cst.c.Get(cst.ts.URL)
   253  	if err != nil {
   254  		t.Fatalf("Get error: %v", err)
   255  	}
   256  	defer res.Body.Close()
   257  	if g, e := res.ContentLength, int64(-1); g != e {
   258  		t.Errorf("expected ContentLength of %d; got %d", e, g)
   259  	}
   260  	wantTE := []string{"chunked"}
   261  	if mode == http2Mode {
   262  		wantTE = nil
   263  	}
   264  	if !reflect.DeepEqual(res.TransferEncoding, wantTE) {
   265  		t.Errorf("TransferEncoding = %v; want %v", res.TransferEncoding, wantTE)
   266  	}
   267  	if got, haveCL := res.Header["Content-Length"]; haveCL {
   268  		t.Errorf("Unexpected Content-Length: %q", got)
   269  	}
   270  }
   271  
   272  type reqFunc func(c *Client, url string) (*Response, error)
   273  
   274  // h12Compare is a test that compares HTTP/1 and HTTP/2 behavior
   275  // against each other.
   276  type h12Compare struct {
   277  	Handler            func(ResponseWriter, *Request)    // required
   278  	ReqFunc            reqFunc                           // optional
   279  	CheckResponse      func(proto string, res *Response) // optional
   280  	EarlyCheckResponse func(proto string, res *Response) // optional; pre-normalize
   281  	Opts               []any
   282  }
   283  
   284  func (tt h12Compare) reqFunc() reqFunc {
   285  	if tt.ReqFunc == nil {
   286  		return (*Client).Get
   287  	}
   288  	return tt.ReqFunc
   289  }
   290  
   291  func (tt h12Compare) run(t *testing.T) {
   292  	setParallel(t)
   293  	cst1 := newClientServerTest(t, http1Mode, HandlerFunc(tt.Handler), tt.Opts...)
   294  	defer cst1.close()
   295  	cst2 := newClientServerTest(t, http2Mode, HandlerFunc(tt.Handler), tt.Opts...)
   296  	defer cst2.close()
   297  
   298  	res1, err := tt.reqFunc()(cst1.c, cst1.ts.URL)
   299  	if err != nil {
   300  		t.Errorf("HTTP/1 request: %v", err)
   301  		return
   302  	}
   303  	res2, err := tt.reqFunc()(cst2.c, cst2.ts.URL)
   304  	if err != nil {
   305  		t.Errorf("HTTP/2 request: %v", err)
   306  		return
   307  	}
   308  
   309  	if fn := tt.EarlyCheckResponse; fn != nil {
   310  		fn("HTTP/1.1", res1)
   311  		fn("HTTP/2.0", res2)
   312  	}
   313  
   314  	tt.normalizeRes(t, res1, "HTTP/1.1")
   315  	tt.normalizeRes(t, res2, "HTTP/2.0")
   316  	res1body, res2body := res1.Body, res2.Body
   317  
   318  	eres1 := mostlyCopy(res1)
   319  	eres2 := mostlyCopy(res2)
   320  	if !reflect.DeepEqual(eres1, eres2) {
   321  		t.Errorf("Response headers to handler differed:\nhttp/1 (%v):\n\t%#v\nhttp/2 (%v):\n\t%#v",
   322  			cst1.ts.URL, eres1, cst2.ts.URL, eres2)
   323  	}
   324  	if !reflect.DeepEqual(res1body, res2body) {
   325  		t.Errorf("Response bodies to handler differed.\nhttp1: %v\nhttp2: %v\n", res1body, res2body)
   326  	}
   327  	if fn := tt.CheckResponse; fn != nil {
   328  		res1.Body, res2.Body = res1body, res2body
   329  		fn("HTTP/1.1", res1)
   330  		fn("HTTP/2.0", res2)
   331  	}
   332  }
   333  
   334  func mostlyCopy(r *Response) *Response {
   335  	c := *r
   336  	c.Body = nil
   337  	c.TransferEncoding = nil
   338  	c.TLS = nil
   339  	c.Request = nil
   340  	return &c
   341  }
   342  
   343  type slurpResult struct {
   344  	io.ReadCloser
   345  	body []byte
   346  	err  error
   347  }
   348  
   349  func (sr slurpResult) String() string { return fmt.Sprintf("body %q; err %v", sr.body, sr.err) }
   350  
   351  func (tt h12Compare) normalizeRes(t *testing.T, res *Response, wantProto string) {
   352  	if res.Proto == wantProto || res.Proto == "HTTP/IGNORE" {
   353  		res.Proto, res.ProtoMajor, res.ProtoMinor = "", 0, 0
   354  	} else {
   355  		t.Errorf("got %q response; want %q", res.Proto, wantProto)
   356  	}
   357  	slurp, err := io.ReadAll(res.Body)
   358  
   359  	res.Body.Close()
   360  	res.Body = slurpResult{
   361  		ReadCloser: io.NopCloser(bytes.NewReader(slurp)),
   362  		body:       slurp,
   363  		err:        err,
   364  	}
   365  	for i, v := range res.Header["Date"] {
   366  		res.Header["Date"][i] = strings.Repeat("x", len(v))
   367  	}
   368  	if res.Request == nil {
   369  		t.Errorf("for %s, no request", wantProto)
   370  	}
   371  	if (res.TLS != nil) != (wantProto == "HTTP/2.0") {
   372  		t.Errorf("TLS set = %v; want %v", res.TLS != nil, res.TLS == nil)
   373  	}
   374  }
   375  
   376  // Issue 13532
   377  func TestH12_HeadContentLengthNoBody(t *testing.T) {
   378  	h12Compare{
   379  		ReqFunc: (*Client).Head,
   380  		Handler: func(w ResponseWriter, r *Request) {
   381  		},
   382  	}.run(t)
   383  }
   384  
   385  func TestH12_HeadContentLengthSmallBody(t *testing.T) {
   386  	h12Compare{
   387  		ReqFunc: (*Client).Head,
   388  		Handler: func(w ResponseWriter, r *Request) {
   389  			io.WriteString(w, "small")
   390  		},
   391  	}.run(t)
   392  }
   393  
   394  func TestH12_HeadContentLengthLargeBody(t *testing.T) {
   395  	h12Compare{
   396  		ReqFunc: (*Client).Head,
   397  		Handler: func(w ResponseWriter, r *Request) {
   398  			chunk := strings.Repeat("x", 512<<10)
   399  			for i := 0; i < 10; i++ {
   400  				io.WriteString(w, chunk)
   401  			}
   402  		},
   403  	}.run(t)
   404  }
   405  
   406  func TestH12_200NoBody(t *testing.T) {
   407  	h12Compare{Handler: func(w ResponseWriter, r *Request) {}}.run(t)
   408  }
   409  
   410  func TestH2_204NoBody(t *testing.T) { testH12_noBody(t, 204) }
   411  func TestH2_304NoBody(t *testing.T) { testH12_noBody(t, 304) }
   412  func TestH2_404NoBody(t *testing.T) { testH12_noBody(t, 404) }
   413  
   414  func testH12_noBody(t *testing.T, status int) {
   415  	h12Compare{Handler: func(w ResponseWriter, r *Request) {
   416  		w.WriteHeader(status)
   417  	}}.run(t)
   418  }
   419  
   420  func TestH12_SmallBody(t *testing.T) {
   421  	h12Compare{Handler: func(w ResponseWriter, r *Request) {
   422  		io.WriteString(w, "small body")
   423  	}}.run(t)
   424  }
   425  
   426  func TestH12_ExplicitContentLength(t *testing.T) {
   427  	h12Compare{Handler: func(w ResponseWriter, r *Request) {
   428  		w.Header().Set("Content-Length", "3")
   429  		io.WriteString(w, "foo")
   430  	}}.run(t)
   431  }
   432  
   433  func TestH12_FlushBeforeBody(t *testing.T) {
   434  	h12Compare{Handler: func(w ResponseWriter, r *Request) {
   435  		w.(Flusher).Flush()
   436  		io.WriteString(w, "foo")
   437  	}}.run(t)
   438  }
   439  
   440  func TestH12_FlushMidBody(t *testing.T) {
   441  	h12Compare{Handler: func(w ResponseWriter, r *Request) {
   442  		io.WriteString(w, "foo")
   443  		w.(Flusher).Flush()
   444  		io.WriteString(w, "bar")
   445  	}}.run(t)
   446  }
   447  
   448  func TestH12_Head_ExplicitLen(t *testing.T) {
   449  	h12Compare{
   450  		ReqFunc: (*Client).Head,
   451  		Handler: func(w ResponseWriter, r *Request) {
   452  			if r.Method != "HEAD" {
   453  				t.Errorf("unexpected method %q", r.Method)
   454  			}
   455  			w.Header().Set("Content-Length", "1235")
   456  		},
   457  	}.run(t)
   458  }
   459  
   460  func TestH12_Head_ImplicitLen(t *testing.T) {
   461  	h12Compare{
   462  		ReqFunc: (*Client).Head,
   463  		Handler: func(w ResponseWriter, r *Request) {
   464  			if r.Method != "HEAD" {
   465  				t.Errorf("unexpected method %q", r.Method)
   466  			}
   467  			io.WriteString(w, "foo")
   468  		},
   469  	}.run(t)
   470  }
   471  
   472  func TestH12_HandlerWritesTooLittle(t *testing.T) {
   473  	h12Compare{
   474  		Handler: func(w ResponseWriter, r *Request) {
   475  			w.Header().Set("Content-Length", "3")
   476  			io.WriteString(w, "12") // one byte short
   477  		},
   478  		CheckResponse: func(proto string, res *Response) {
   479  			sr, ok := res.Body.(slurpResult)
   480  			if !ok {
   481  				t.Errorf("%s body is %T; want slurpResult", proto, res.Body)
   482  				return
   483  			}
   484  			if sr.err != io.ErrUnexpectedEOF {
   485  				t.Errorf("%s read error = %v; want io.ErrUnexpectedEOF", proto, sr.err)
   486  			}
   487  			if string(sr.body) != "12" {
   488  				t.Errorf("%s body = %q; want %q", proto, sr.body, "12")
   489  			}
   490  		},
   491  	}.run(t)
   492  }
   493  
   494  // Tests that the HTTP/1 and HTTP/2 servers prevent handlers from
   495  // writing more than they declared. This test does not test whether
   496  // the transport deals with too much data, though, since the server
   497  // doesn't make it possible to send bogus data. For those tests, see
   498  // transport_test.go (for HTTP/1) or x/net/http2/transport_test.go
   499  // (for HTTP/2).
   500  func TestH12_HandlerWritesTooMuch(t *testing.T) {
   501  	h12Compare{
   502  		Handler: func(w ResponseWriter, r *Request) {
   503  			w.Header().Set("Content-Length", "3")
   504  			w.(Flusher).Flush()
   505  			io.WriteString(w, "123")
   506  			w.(Flusher).Flush()
   507  			n, err := io.WriteString(w, "x") // too many
   508  			if n > 0 || err == nil {
   509  				t.Errorf("for proto %q, final write = %v, %v; want 0, some error", r.Proto, n, err)
   510  			}
   511  		},
   512  	}.run(t)
   513  }
   514  
   515  // Verify that both our HTTP/1 and HTTP/2 request and auto-decompress gzip.
   516  // Some hosts send gzip even if you don't ask for it; see golang.org/issue/13298
   517  func TestH12_AutoGzip(t *testing.T) {
   518  	h12Compare{
   519  		Handler: func(w ResponseWriter, r *Request) {
   520  			if ae := r.Header.Get("Accept-Encoding"); ae != "gzip" {
   521  				t.Errorf("%s Accept-Encoding = %q; want gzip", r.Proto, ae)
   522  			}
   523  			w.Header().Set("Content-Encoding", "gzip")
   524  			gz := gzip.NewWriter(w)
   525  			io.WriteString(gz, "I am some gzipped content. Go go go go go go go go go go go go should compress well.")
   526  			gz.Close()
   527  		},
   528  	}.run(t)
   529  }
   530  
   531  func TestH12_AutoGzip_Disabled(t *testing.T) {
   532  	h12Compare{
   533  		Opts: []any{
   534  			func(tr *Transport) { tr.DisableCompression = true },
   535  		},
   536  		Handler: func(w ResponseWriter, r *Request) {
   537  			fmt.Fprintf(w, "%q", r.Header["Accept-Encoding"])
   538  			if ae := r.Header.Get("Accept-Encoding"); ae != "" {
   539  				t.Errorf("%s Accept-Encoding = %q; want empty", r.Proto, ae)
   540  			}
   541  		},
   542  	}.run(t)
   543  }
   544  
   545  // Test304Responses verifies that 304s don't declare that they're
   546  // chunking in their response headers and aren't allowed to produce
   547  // output.
   548  func Test304Responses(t *testing.T) { run(t, test304Responses) }
   549  func test304Responses(t *testing.T, mode testMode) {
   550  	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
   551  		w.WriteHeader(StatusNotModified)
   552  		_, err := w.Write([]byte("illegal body"))
   553  		if err != ErrBodyNotAllowed {
   554  			t.Errorf("on Write, expected ErrBodyNotAllowed, got %v", err)
   555  		}
   556  	}))
   557  	defer cst.close()
   558  	res, err := cst.c.Get(cst.ts.URL)
   559  	if err != nil {
   560  		t.Fatal(err)
   561  	}
   562  	if len(res.TransferEncoding) > 0 {
   563  		t.Errorf("expected no TransferEncoding; got %v", res.TransferEncoding)
   564  	}
   565  	body, err := io.ReadAll(res.Body)
   566  	if err != nil {
   567  		t.Error(err)
   568  	}
   569  	if len(body) > 0 {
   570  		t.Errorf("got unexpected body %q", string(body))
   571  	}
   572  }
   573  
   574  func TestH12_ServerEmptyContentLength(t *testing.T) {
   575  	h12Compare{
   576  		Handler: func(w ResponseWriter, r *Request) {
   577  			w.Header()["Content-Type"] = []string{""}
   578  			io.WriteString(w, "<html><body>hi</body></html>")
   579  		},
   580  	}.run(t)
   581  }
   582  
   583  func TestH12_RequestContentLength_Known_NonZero(t *testing.T) {
   584  	h12requestContentLength(t, func() io.Reader { return strings.NewReader("FOUR") }, 4)
   585  }
   586  
   587  func TestH12_RequestContentLength_Known_Zero(t *testing.T) {
   588  	h12requestContentLength(t, func() io.Reader { return nil }, 0)
   589  }
   590  
   591  func TestH12_RequestContentLength_Unknown(t *testing.T) {
   592  	h12requestContentLength(t, func() io.Reader { return struct{ io.Reader }{strings.NewReader("Stuff")} }, -1)
   593  }
   594  
   595  func h12requestContentLength(t *testing.T, bodyfn func() io.Reader, wantLen int64) {
   596  	h12Compare{
   597  		Handler: func(w ResponseWriter, r *Request) {
   598  			w.Header().Set("Got-Length", fmt.Sprint(r.ContentLength))
   599  			fmt.Fprintf(w, "Req.ContentLength=%v", r.ContentLength)
   600  		},
   601  		ReqFunc: func(c *Client, url string) (*Response, error) {
   602  			return c.Post(url, "text/plain", bodyfn())
   603  		},
   604  		CheckResponse: func(proto string, res *Response) {
   605  			if got, want := res.Header.Get("Got-Length"), fmt.Sprint(wantLen); got != want {
   606  				t.Errorf("Proto %q got length %q; want %q", proto, got, want)
   607  			}
   608  		},
   609  	}.run(t)
   610  }
   611  
   612  // Tests that closing the Request.Cancel channel also while still
   613  // reading the response body. Issue 13159.
   614  func TestCancelRequestMidBody(t *testing.T) { run(t, testCancelRequestMidBody) }
   615  func testCancelRequestMidBody(t *testing.T, mode testMode) {
   616  	unblock := make(chan bool)
   617  	didFlush := make(chan bool, 1)
   618  	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
   619  		io.WriteString(w, "Hello")
   620  		w.(Flusher).Flush()
   621  		didFlush <- true
   622  		<-unblock
   623  		io.WriteString(w, ", world.")
   624  	}))
   625  	defer close(unblock)
   626  
   627  	req, _ := NewRequest("GET", cst.ts.URL, nil)
   628  	cancel := make(chan struct{})
   629  	req.Cancel = cancel
   630  
   631  	res, err := cst.c.Do(req)
   632  	if err != nil {
   633  		t.Fatal(err)
   634  	}
   635  	defer res.Body.Close()
   636  	<-didFlush
   637  
   638  	// Read a bit before we cancel. (Issue 13626)
   639  	// We should have "Hello" at least sitting there.
   640  	firstRead := make([]byte, 10)
   641  	n, err := res.Body.Read(firstRead)
   642  	if err != nil {
   643  		t.Fatal(err)
   644  	}
   645  	firstRead = firstRead[:n]
   646  
   647  	close(cancel)
   648  
   649  	rest, err := io.ReadAll(res.Body)
   650  	all := string(firstRead) + string(rest)
   651  	if all != "Hello" {
   652  		t.Errorf("Read %q (%q + %q); want Hello", all, firstRead, rest)
   653  	}
   654  	if err != ExportErrRequestCanceled {
   655  		t.Errorf("ReadAll error = %v; want %v", err, ExportErrRequestCanceled)
   656  	}
   657  }
   658  
   659  // Tests that clients can send trailers to a server and that the server can read them.
   660  func TestTrailersClientToServer(t *testing.T) { run(t, testTrailersClientToServer) }
   661  func testTrailersClientToServer(t *testing.T, mode testMode) {
   662  	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
   663  		var decl []string
   664  		for k := range r.Trailer {
   665  			decl = append(decl, k)
   666  		}
   667  		sort.Strings(decl)
   668  
   669  		slurp, err := io.ReadAll(r.Body)
   670  		if err != nil {
   671  			t.Errorf("Server reading request body: %v", err)
   672  		}
   673  		if string(slurp) != "foo" {
   674  			t.Errorf("Server read request body %q; want foo", slurp)
   675  		}
   676  		if r.Trailer == nil {
   677  			io.WriteString(w, "nil Trailer")
   678  		} else {
   679  			fmt.Fprintf(w, "decl: %v, vals: %s, %s",
   680  				decl,
   681  				r.Trailer.Get("Client-Trailer-A"),
   682  				r.Trailer.Get("Client-Trailer-B"))
   683  		}
   684  	}))
   685  
   686  	var req *Request
   687  	req, _ = NewRequest("POST", cst.ts.URL, io.MultiReader(
   688  		eofReaderFunc(func() {
   689  			req.Trailer["Client-Trailer-A"] = []string{"valuea"}
   690  		}),
   691  		strings.NewReader("foo"),
   692  		eofReaderFunc(func() {
   693  			req.Trailer["Client-Trailer-B"] = []string{"valueb"}
   694  		}),
   695  	))
   696  	req.Trailer = Header{
   697  		"Client-Trailer-A": nil, //  to be set later
   698  		"Client-Trailer-B": nil, //  to be set later
   699  	}
   700  	req.ContentLength = -1
   701  	res, err := cst.c.Do(req)
   702  	if err != nil {
   703  		t.Fatal(err)
   704  	}
   705  	if err := wantBody(res, err, "decl: [Client-Trailer-A Client-Trailer-B], vals: valuea, valueb"); err != nil {
   706  		t.Error(err)
   707  	}
   708  }
   709  
   710  // Tests that servers send trailers to a client and that the client can read them.
   711  func TestTrailersServerToClient(t *testing.T) {
   712  	run(t, func(t *testing.T, mode testMode) {
   713  		testTrailersServerToClient(t, mode, false)
   714  	})
   715  }
   716  func TestTrailersServerToClientFlush(t *testing.T) {
   717  	run(t, func(t *testing.T, mode testMode) {
   718  		testTrailersServerToClient(t, mode, true)
   719  	})
   720  }
   721  
   722  func testTrailersServerToClient(t *testing.T, mode testMode, flush bool) {
   723  	const body = "Some body"
   724  	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
   725  		w.Header().Set("Trailer", "Server-Trailer-A, Server-Trailer-B")
   726  		w.Header().Add("Trailer", "Server-Trailer-C")
   727  
   728  		io.WriteString(w, body)
   729  		if flush {
   730  			w.(Flusher).Flush()
   731  		}
   732  
   733  		// How handlers set Trailers: declare it ahead of time
   734  		// with the Trailer header, and then mutate the
   735  		// Header() of those values later, after the response
   736  		// has been written (we wrote to w above).
   737  		w.Header().Set("Server-Trailer-A", "valuea")
   738  		w.Header().Set("Server-Trailer-C", "valuec") // skipping B
   739  		w.Header().Set("Server-Trailer-NotDeclared", "should be omitted")
   740  	}))
   741  
   742  	res, err := cst.c.Get(cst.ts.URL)
   743  	if err != nil {
   744  		t.Fatal(err)
   745  	}
   746  
   747  	wantHeader := Header{
   748  		"Content-Type": {"text/plain; charset=utf-8"},
   749  	}
   750  	wantLen := -1
   751  	if mode == http2Mode && !flush {
   752  		// In HTTP/1.1, any use of trailers forces HTTP/1.1
   753  		// chunking and a flush at the first write. That's
   754  		// unnecessary with HTTP/2's framing, so the server
   755  		// is able to calculate the length while still sending
   756  		// trailers afterwards.
   757  		wantLen = len(body)
   758  		wantHeader["Content-Length"] = []string{fmt.Sprint(wantLen)}
   759  	}
   760  	if res.ContentLength != int64(wantLen) {
   761  		t.Errorf("ContentLength = %v; want %v", res.ContentLength, wantLen)
   762  	}
   763  
   764  	delete(res.Header, "Date") // irrelevant for test
   765  	if !reflect.DeepEqual(res.Header, wantHeader) {
   766  		t.Errorf("Header = %v; want %v", res.Header, wantHeader)
   767  	}
   768  
   769  	if got, want := res.Trailer, (Header{
   770  		"Server-Trailer-A": nil,
   771  		"Server-Trailer-B": nil,
   772  		"Server-Trailer-C": nil,
   773  	}); !reflect.DeepEqual(got, want) {
   774  		t.Errorf("Trailer before body read = %v; want %v", got, want)
   775  	}
   776  
   777  	if err := wantBody(res, nil, body); err != nil {
   778  		t.Fatal(err)
   779  	}
   780  
   781  	if got, want := res.Trailer, (Header{
   782  		"Server-Trailer-A": {"valuea"},
   783  		"Server-Trailer-B": nil,
   784  		"Server-Trailer-C": {"valuec"},
   785  	}); !reflect.DeepEqual(got, want) {
   786  		t.Errorf("Trailer after body read = %v; want %v", got, want)
   787  	}
   788  }
   789  
   790  // Don't allow a Body.Read after Body.Close. Issue 13648.
   791  func TestResponseBodyReadAfterClose(t *testing.T) { run(t, testResponseBodyReadAfterClose) }
   792  func testResponseBodyReadAfterClose(t *testing.T, mode testMode) {
   793  	const body = "Some body"
   794  	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
   795  		io.WriteString(w, body)
   796  	}))
   797  	res, err := cst.c.Get(cst.ts.URL)
   798  	if err != nil {
   799  		t.Fatal(err)
   800  	}
   801  	res.Body.Close()
   802  	data, err := io.ReadAll(res.Body)
   803  	if len(data) != 0 || err == nil {
   804  		t.Fatalf("ReadAll returned %q, %v; want error", data, err)
   805  	}
   806  }
   807  
   808  func TestConcurrentReadWriteReqBody(t *testing.T) { run(t, testConcurrentReadWriteReqBody) }
   809  func testConcurrentReadWriteReqBody(t *testing.T, mode testMode) {
   810  	const reqBody = "some request body"
   811  	const resBody = "some response body"
   812  	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
   813  		var wg sync.WaitGroup
   814  		wg.Add(2)
   815  		didRead := make(chan bool, 1)
   816  		// Read in one goroutine.
   817  		go func() {
   818  			defer wg.Done()
   819  			data, err := io.ReadAll(r.Body)
   820  			if string(data) != reqBody {
   821  				t.Errorf("Handler read %q; want %q", data, reqBody)
   822  			}
   823  			if err != nil {
   824  				t.Errorf("Handler Read: %v", err)
   825  			}
   826  			didRead <- true
   827  		}()
   828  		// Write in another goroutine.
   829  		go func() {
   830  			defer wg.Done()
   831  			if mode != http2Mode {
   832  				// our HTTP/1 implementation intentionally
   833  				// doesn't permit writes during read (mostly
   834  				// due to it being undefined); if that is ever
   835  				// relaxed, change this.
   836  				<-didRead
   837  			}
   838  			io.WriteString(w, resBody)
   839  		}()
   840  		wg.Wait()
   841  	}))
   842  	req, _ := NewRequest("POST", cst.ts.URL, strings.NewReader(reqBody))
   843  	req.Header.Add("Expect", "100-continue") // just to complicate things
   844  	res, err := cst.c.Do(req)
   845  	if err != nil {
   846  		t.Fatal(err)
   847  	}
   848  	data, err := io.ReadAll(res.Body)
   849  	defer res.Body.Close()
   850  	if err != nil {
   851  		t.Fatal(err)
   852  	}
   853  	if string(data) != resBody {
   854  		t.Errorf("read %q; want %q", data, resBody)
   855  	}
   856  }
   857  
   858  func TestConnectRequest(t *testing.T) { run(t, testConnectRequest) }
   859  func testConnectRequest(t *testing.T, mode testMode) {
   860  	gotc := make(chan *Request, 1)
   861  	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
   862  		gotc <- r
   863  	}))
   864  
   865  	u, err := url.Parse(cst.ts.URL)
   866  	if err != nil {
   867  		t.Fatal(err)
   868  	}
   869  
   870  	tests := []struct {
   871  		req  *Request
   872  		want string
   873  	}{
   874  		{
   875  			req: &Request{
   876  				Method: "CONNECT",
   877  				Header: Header{},
   878  				URL:    u,
   879  			},
   880  			want: u.Host,
   881  		},
   882  		{
   883  			req: &Request{
   884  				Method: "CONNECT",
   885  				Header: Header{},
   886  				URL:    u,
   887  				Host:   "example.com:123",
   888  			},
   889  			want: "example.com:123",
   890  		},
   891  	}
   892  
   893  	for i, tt := range tests {
   894  		res, err := cst.c.Do(tt.req)
   895  		if err != nil {
   896  			t.Errorf("%d. RoundTrip = %v", i, err)
   897  			continue
   898  		}
   899  		res.Body.Close()
   900  		req := <-gotc
   901  		if req.Method != "CONNECT" {
   902  			t.Errorf("method = %q; want CONNECT", req.Method)
   903  		}
   904  		if req.Host != tt.want {
   905  			t.Errorf("Host = %q; want %q", req.Host, tt.want)
   906  		}
   907  		if req.URL.Host != tt.want {
   908  			t.Errorf("URL.Host = %q; want %q", req.URL.Host, tt.want)
   909  		}
   910  	}
   911  }
   912  
   913  func TestTransportUserAgent(t *testing.T) { run(t, testTransportUserAgent) }
   914  func testTransportUserAgent(t *testing.T, mode testMode) {
   915  	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
   916  		fmt.Fprintf(w, "%q", r.Header["User-Agent"])
   917  	}))
   918  
   919  	either := func(a, b string) string {
   920  		if mode == http2Mode {
   921  			return b
   922  		}
   923  		return a
   924  	}
   925  
   926  	tests := []struct {
   927  		setup func(*Request)
   928  		want  string
   929  	}{
   930  		{
   931  			func(r *Request) {},
   932  			either(`["Go-http-client/1.1"]`, `["Go-http-client/2.0"]`),
   933  		},
   934  		{
   935  			func(r *Request) { r.Header.Set("User-Agent", "foo/1.2.3") },
   936  			`["foo/1.2.3"]`,
   937  		},
   938  		{
   939  			func(r *Request) { r.Header["User-Agent"] = []string{"single", "or", "multiple"} },
   940  			`["single"]`,
   941  		},
   942  		{
   943  			func(r *Request) { r.Header.Set("User-Agent", "") },
   944  			`[]`,
   945  		},
   946  		{
   947  			func(r *Request) { r.Header["User-Agent"] = nil },
   948  			`[]`,
   949  		},
   950  	}
   951  	for i, tt := range tests {
   952  		req, _ := NewRequest("GET", cst.ts.URL, nil)
   953  		tt.setup(req)
   954  		res, err := cst.c.Do(req)
   955  		if err != nil {
   956  			t.Errorf("%d. RoundTrip = %v", i, err)
   957  			continue
   958  		}
   959  		slurp, err := io.ReadAll(res.Body)
   960  		res.Body.Close()
   961  		if err != nil {
   962  			t.Errorf("%d. read body = %v", i, err)
   963  			continue
   964  		}
   965  		if string(slurp) != tt.want {
   966  			t.Errorf("%d. body mismatch.\n got: %s\nwant: %s\n", i, slurp, tt.want)
   967  		}
   968  	}
   969  }
   970  
   971  func TestStarRequestMethod(t *testing.T) {
   972  	for _, method := range []string{"FOO", "OPTIONS"} {
   973  		t.Run(method, func(t *testing.T) {
   974  			run(t, func(t *testing.T, mode testMode) {
   975  				testStarRequest(t, method, mode)
   976  			})
   977  		})
   978  	}
   979  }
   980  func testStarRequest(t *testing.T, method string, mode testMode) {
   981  	gotc := make(chan *Request, 1)
   982  	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
   983  		w.Header().Set("foo", "bar")
   984  		gotc <- r
   985  		w.(Flusher).Flush()
   986  	}))
   987  
   988  	u, err := url.Parse(cst.ts.URL)
   989  	if err != nil {
   990  		t.Fatal(err)
   991  	}
   992  	u.Path = "*"
   993  
   994  	req := &Request{
   995  		Method: method,
   996  		Header: Header{},
   997  		URL:    u,
   998  	}
   999  
  1000  	res, err := cst.c.Do(req)
  1001  	if err != nil {
  1002  		t.Fatalf("RoundTrip = %v", err)
  1003  	}
  1004  	res.Body.Close()
  1005  
  1006  	wantFoo := "bar"
  1007  	wantLen := int64(-1)
  1008  	if method == "OPTIONS" {
  1009  		wantFoo = ""
  1010  		wantLen = 0
  1011  	}
  1012  	if res.StatusCode != 200 {
  1013  		t.Errorf("status code = %v; want %d", res.Status, 200)
  1014  	}
  1015  	if res.ContentLength != wantLen {
  1016  		t.Errorf("content length = %v; want %d", res.ContentLength, wantLen)
  1017  	}
  1018  	if got := res.Header.Get("foo"); got != wantFoo {
  1019  		t.Errorf("response \"foo\" header = %q; want %q", got, wantFoo)
  1020  	}
  1021  	select {
  1022  	case req = <-gotc:
  1023  	default:
  1024  		req = nil
  1025  	}
  1026  	if req == nil {
  1027  		if method != "OPTIONS" {
  1028  			t.Fatalf("handler never got request")
  1029  		}
  1030  		return
  1031  	}
  1032  	if req.Method != method {
  1033  		t.Errorf("method = %q; want %q", req.Method, method)
  1034  	}
  1035  	if req.URL.Path != "*" {
  1036  		t.Errorf("URL.Path = %q; want *", req.URL.Path)
  1037  	}
  1038  	if req.RequestURI != "*" {
  1039  		t.Errorf("RequestURI = %q; want *", req.RequestURI)
  1040  	}
  1041  }
  1042  
  1043  // Issue 13957
  1044  func TestTransportDiscardsUnneededConns(t *testing.T) {
  1045  	run(t, testTransportDiscardsUnneededConns, []testMode{http2Mode})
  1046  }
  1047  func testTransportDiscardsUnneededConns(t *testing.T, mode testMode) {
  1048  	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
  1049  		fmt.Fprintf(w, "Hello, %v", r.RemoteAddr)
  1050  	}))
  1051  	defer cst.close()
  1052  
  1053  	var numOpen, numClose int32 // atomic
  1054  
  1055  	tlsConfig := &tls.Config{InsecureSkipVerify: true}
  1056  	tr := &Transport{
  1057  		TLSClientConfig: tlsConfig,
  1058  		DialTLS: func(_, addr string) (net.Conn, error) {
  1059  			time.Sleep(10 * time.Millisecond)
  1060  			rc, err := net.Dial("tcp", addr)
  1061  			if err != nil {
  1062  				return nil, err
  1063  			}
  1064  			atomic.AddInt32(&numOpen, 1)
  1065  			c := noteCloseConn{rc, func() { atomic.AddInt32(&numClose, 1) }}
  1066  			return tls.Client(c, tlsConfig), nil
  1067  		},
  1068  	}
  1069  	if err := ExportHttp2ConfigureTransport(tr); err != nil {
  1070  		t.Fatal(err)
  1071  	}
  1072  	defer tr.CloseIdleConnections()
  1073  
  1074  	c := &Client{Transport: tr}
  1075  
  1076  	const N = 10
  1077  	gotBody := make(chan string, N)
  1078  	var wg sync.WaitGroup
  1079  	for i := 0; i < N; i++ {
  1080  		wg.Add(1)
  1081  		go func() {
  1082  			defer wg.Done()
  1083  			resp, err := c.Get(cst.ts.URL)
  1084  			if err != nil {
  1085  				// Try to work around spurious connection reset on loaded system.
  1086  				// See golang.org/issue/33585 and golang.org/issue/36797.
  1087  				time.Sleep(10 * time.Millisecond)
  1088  				resp, err = c.Get(cst.ts.URL)
  1089  				if err != nil {
  1090  					t.Errorf("Get: %v", err)
  1091  					return
  1092  				}
  1093  			}
  1094  			defer resp.Body.Close()
  1095  			slurp, err := io.ReadAll(resp.Body)
  1096  			if err != nil {
  1097  				t.Error(err)
  1098  			}
  1099  			gotBody <- string(slurp)
  1100  		}()
  1101  	}
  1102  	wg.Wait()
  1103  	close(gotBody)
  1104  
  1105  	var last string
  1106  	for got := range gotBody {
  1107  		if last == "" {
  1108  			last = got
  1109  			continue
  1110  		}
  1111  		if got != last {
  1112  			t.Errorf("Response body changed: %q -> %q", last, got)
  1113  		}
  1114  	}
  1115  
  1116  	var open, close int32
  1117  	for i := 0; i < 150; i++ {
  1118  		open, close = atomic.LoadInt32(&numOpen), atomic.LoadInt32(&numClose)
  1119  		if open < 1 {
  1120  			t.Fatalf("open = %d; want at least", open)
  1121  		}
  1122  		if close == open-1 {
  1123  			// Success
  1124  			return
  1125  		}
  1126  		time.Sleep(10 * time.Millisecond)
  1127  	}
  1128  	t.Errorf("%d connections opened, %d closed; want %d to close", open, close, open-1)
  1129  }
  1130  
  1131  // tests that Transport doesn't retain a pointer to the provided request.
  1132  func TestTransportGCRequest(t *testing.T) {
  1133  	run(t, func(t *testing.T, mode testMode) {
  1134  		t.Run("Body", func(t *testing.T) { testTransportGCRequest(t, mode, true) })
  1135  		t.Run("NoBody", func(t *testing.T) { testTransportGCRequest(t, mode, false) })
  1136  	})
  1137  }
  1138  func testTransportGCRequest(t *testing.T, mode testMode, body bool) {
  1139  	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
  1140  		io.ReadAll(r.Body)
  1141  		if body {
  1142  			io.WriteString(w, "Hello.")
  1143  		}
  1144  	}))
  1145  
  1146  	didGC := make(chan struct{})
  1147  	(func() {
  1148  		body := strings.NewReader("some body")
  1149  		req, _ := NewRequest("POST", cst.ts.URL, body)
  1150  		runtime.SetFinalizer(req, func(*Request) { close(didGC) })
  1151  		res, err := cst.c.Do(req)
  1152  		if err != nil {
  1153  			t.Fatal(err)
  1154  		}
  1155  		if _, err := io.ReadAll(res.Body); err != nil {
  1156  			t.Fatal(err)
  1157  		}
  1158  		if err := res.Body.Close(); err != nil {
  1159  			t.Fatal(err)
  1160  		}
  1161  	})()
  1162  	timeout := time.NewTimer(5 * time.Second)
  1163  	defer timeout.Stop()
  1164  	for {
  1165  		select {
  1166  		case <-didGC:
  1167  			return
  1168  		case <-time.After(100 * time.Millisecond):
  1169  			runtime.GC()
  1170  		case <-timeout.C:
  1171  			t.Fatal("never saw GC of request")
  1172  		}
  1173  	}
  1174  }
  1175  
  1176  func TestTransportRejectsInvalidHeaders(t *testing.T) { run(t, testTransportRejectsInvalidHeaders) }
  1177  func testTransportRejectsInvalidHeaders(t *testing.T, mode testMode) {
  1178  	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
  1179  		fmt.Fprintf(w, "Handler saw headers: %q", r.Header)
  1180  	}), optQuietLog)
  1181  	cst.tr.DisableKeepAlives = true
  1182  
  1183  	tests := []struct {
  1184  		key, val string
  1185  		ok       bool
  1186  	}{
  1187  		{"Foo", "capital-key", true}, // verify h2 allows capital keys
  1188  		{"Foo", "foo\x00bar", false}, // \x00 byte in value not allowed
  1189  		{"Foo", "two\nlines", false}, // \n byte in value not allowed
  1190  		{"bogus\nkey", "v", false},   // \n byte also not allowed in key
  1191  		{"A space", "v", false},      // spaces in keys not allowed
  1192  		{"имя", "v", false},          // key must be ascii
  1193  		{"name", "валю", true},       // value may be non-ascii
  1194  		{"", "v", false},             // key must be non-empty
  1195  		{"k", "", true},              // value may be empty
  1196  	}
  1197  	for _, tt := range tests {
  1198  		dialedc := make(chan bool, 1)
  1199  		cst.tr.Dial = func(netw, addr string) (net.Conn, error) {
  1200  			dialedc <- true
  1201  			return net.Dial(netw, addr)
  1202  		}
  1203  		req, _ := NewRequest("GET", cst.ts.URL, nil)
  1204  		req.Header[tt.key] = []string{tt.val}
  1205  		res, err := cst.c.Do(req)
  1206  		var body []byte
  1207  		if err == nil {
  1208  			body, _ = io.ReadAll(res.Body)
  1209  			res.Body.Close()
  1210  		}
  1211  		var dialed bool
  1212  		select {
  1213  		case <-dialedc:
  1214  			dialed = true
  1215  		default:
  1216  		}
  1217  
  1218  		if !tt.ok && dialed {
  1219  			t.Errorf("For key %q, value %q, transport dialed. Expected local failure. Response was: (%v, %v)\nServer replied with: %s", tt.key, tt.val, res, err, body)
  1220  		} else if (err == nil) != tt.ok {
  1221  			t.Errorf("For key %q, value %q; got err = %v; want ok=%v", tt.key, tt.val, err, tt.ok)
  1222  		}
  1223  	}
  1224  }
  1225  
// TestInterruptWithPanic exercises a handler that panics after part of the
// response has already been flushed. It uses three panic values: a string,
// nil, and ErrAbortHandler.
func TestInterruptWithPanic(t *testing.T) {
	run(t, func(t *testing.T, mode testMode) {
		t.Run("boom", func(t *testing.T) { testInterruptWithPanic(t, mode, "boom") })
		t.Run("nil", func(t *testing.T) { testInterruptWithPanic(t, mode, nil) })
		t.Run("ErrAbortHandler", func(t *testing.T) { testInterruptWithPanic(t, mode, ErrAbortHandler) })
	})
}

// testInterruptWithPanic runs a handler that writes and flushes msg. It then
// waits until the client has the response headers (or the test is returning)
// and panics with panicValue.
//
// The test expects the client to read exactly msg, followed by an error. The
// server's ErrorLog must receive something that looks like a stack trace when
// panicValue is neither nil nor ErrAbortHandler, and nothing at all otherwise.
func testInterruptWithPanic(t *testing.T, mode testMode, panicValue any) {
	const msg = "hello"

	// Closed when the test returns so the handler's select below can never
	// stay blocked, e.g. if Get fails before gotHeaders is sent.
	testDone := make(chan struct{})
	defer close(testDone)

	var errorLog lockedBytesBuffer
	gotHeaders := make(chan bool, 1)
	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		io.WriteString(w, msg)
		w.(Flusher).Flush()

		// Hold the panic until the client has received the headers.
		select {
		case <-gotHeaders:
		case <-testDone:
		}
		panic(panicValue)
	}), func(ts *httptest.Server) {
		// Capture the server's log output for the checks below.
		ts.Config.ErrorLog = log.New(&errorLog, "", 0)
	})
	res, err := cst.c.Get(cst.ts.URL)
	if err != nil {
		t.Fatal(err)
	}
	// Release the handler so it panics mid-body.
	gotHeaders <- true
	defer res.Body.Close()
	slurp, err := io.ReadAll(res.Body)
	if string(slurp) != msg {
		t.Errorf("client read %q; want %q", slurp, msg)
	}
	if err == nil {
		t.Errorf("client read all successfully; want some error")
	}
	// Read the log under the buffer's mutex: lockedBytesBuffer only locks
	// Write, so String must be guarded by the caller.
	logOutput := func() string {
		errorLog.Lock()
		defer errorLog.Unlock()
		return errorLog.String()
	}
	wantStackLogged := panicValue != nil && panicValue != ErrAbortHandler

	// The log may not be complete by the time the client sees the error, so
	// poll the condition instead of checking it once.
	if err := waitErrCondition(5*time.Second, 10*time.Millisecond, func() error {
		gotLog := logOutput()
		if !wantStackLogged {
			if gotLog == "" {
				return nil
			}
			return fmt.Errorf("want no log output; got: %s", gotLog)
		}
		if gotLog == "" {
			return fmt.Errorf("wanted a stack trace logged; got nothing")
		}
		// Accept output that either contains a "created by " goroutine
		// marker or has at least 6 lines.
		if !strings.Contains(gotLog, "created by ") && strings.Count(gotLog, "\n") < 6 {
			return fmt.Errorf("output doesn't look like a panic stack trace. Got: %s", gotLog)
		}
		return nil
	}); err != nil {
		t.Fatal(err)
	}
}
  1292  
  1293  type lockedBytesBuffer struct {
  1294  	sync.Mutex
  1295  	bytes.Buffer
  1296  }
  1297  
  1298  func (b *lockedBytesBuffer) Write(p []byte) (int, error) {
  1299  	b.Lock()
  1300  	defer b.Unlock()
  1301  	return b.Buffer.Write(p)
  1302  }
  1303  
  1304  // Issue 15366
  1305  func TestH12_AutoGzipWithDumpResponse(t *testing.T) {
  1306  	h12Compare{
  1307  		Handler: func(w ResponseWriter, r *Request) {
  1308  			h := w.Header()
  1309  			h.Set("Content-Encoding", "gzip")
  1310  			h.Set("Content-Length", "23")
  1311  			io.WriteString(w, "\x1f\x8b\b\x00\x00\x00\x00\x00\x00\x00s\xf3\xf7\a\x00\xab'\xd4\x1a\x03\x00\x00\x00")
  1312  		},
  1313  		EarlyCheckResponse: func(proto string, res *Response) {
  1314  			if !res.Uncompressed {
  1315  				t.Errorf("%s: expected Uncompressed to be set", proto)
  1316  			}
  1317  			dump, err := httputil.DumpResponse(res, true)
  1318  			if err != nil {
  1319  				t.Errorf("%s: DumpResponse: %v", proto, err)
  1320  				return
  1321  			}
  1322  			if strings.Contains(string(dump), "Connection: close") {
  1323  				t.Errorf("%s: should not see \"Connection: close\" in dump; got:\n%s", proto, dump)
  1324  			}
  1325  			if !strings.Contains(string(dump), "FOO") {
  1326  				t.Errorf("%s: should see \"FOO\" in response; got:\n%s", proto, dump)
  1327  			}
  1328  		},
  1329  	}.run(t)
  1330  }
  1331  
  1332  // Issue 14607
  1333  func TestCloseIdleConnections(t *testing.T) { run(t, testCloseIdleConnections) }
  1334  func testCloseIdleConnections(t *testing.T, mode testMode) {
  1335  	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
  1336  		w.Header().Set("X-Addr", r.RemoteAddr)
  1337  	}))
  1338  	get := func() string {
  1339  		res, err := cst.c.Get(cst.ts.URL)
  1340  		if err != nil {
  1341  			t.Fatal(err)
  1342  		}
  1343  		res.Body.Close()
  1344  		v := res.Header.Get("X-Addr")
  1345  		if v == "" {
  1346  			t.Fatal("didn't get X-Addr")
  1347  		}
  1348  		return v
  1349  	}
  1350  	a1 := get()
  1351  	cst.tr.CloseIdleConnections()
  1352  	a2 := get()
  1353  	if a1 == a2 {
  1354  		t.Errorf("didn't close connection")
  1355  	}
  1356  }
  1357  
  1358  type noteCloseConn struct {
  1359  	net.Conn
  1360  	closeFunc func()
  1361  }
  1362  
  1363  func (x noteCloseConn) Close() error {
  1364  	x.closeFunc()
  1365  	return x.Conn.Close()
  1366  }
  1367  
  1368  type testErrorReader struct{ t *testing.T }
  1369  
  1370  func (r testErrorReader) Read(p []byte) (n int, err error) {
  1371  	r.t.Error("unexpected Read call")
  1372  	return 0, io.EOF
  1373  }
  1374  
  1375  func TestNoSniffExpectRequestBody(t *testing.T) { run(t, testNoSniffExpectRequestBody) }
  1376  func testNoSniffExpectRequestBody(t *testing.T, mode testMode) {
  1377  	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
  1378  		w.WriteHeader(StatusUnauthorized)
  1379  	}))
  1380  
  1381  	// Set ExpectContinueTimeout non-zero so RoundTrip won't try to write it.
  1382  	cst.tr.ExpectContinueTimeout = 10 * time.Second
  1383  
  1384  	req, err := NewRequest("POST", cst.ts.URL, testErrorReader{t})
  1385  	if err != nil {
  1386  		t.Fatal(err)
  1387  	}
  1388  	req.ContentLength = 0 // so transport is tempted to sniff it
  1389  	req.Header.Set("Expect", "100-continue")
  1390  	res, err := cst.tr.RoundTrip(req)
  1391  	if err != nil {
  1392  		t.Fatal(err)
  1393  	}
  1394  	defer res.Body.Close()
  1395  	if res.StatusCode != StatusUnauthorized {
  1396  		t.Errorf("status code = %v; want %v", res.StatusCode, StatusUnauthorized)
  1397  	}
  1398  }
  1399  
  1400  func TestServerUndeclaredTrailers(t *testing.T) { run(t, testServerUndeclaredTrailers) }
  1401  func testServerUndeclaredTrailers(t *testing.T, mode testMode) {
  1402  	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
  1403  		w.Header().Set("Foo", "Bar")
  1404  		w.Header().Set("Trailer:Foo", "Baz")
  1405  		w.(Flusher).Flush()
  1406  		w.Header().Add("Trailer:Foo", "Baz2")
  1407  		w.Header().Set("Trailer:Bar", "Quux")
  1408  	}))
  1409  	res, err := cst.c.Get(cst.ts.URL)
  1410  	if err != nil {
  1411  		t.Fatal(err)
  1412  	}
  1413  	if _, err := io.Copy(io.Discard, res.Body); err != nil {
  1414  		t.Fatal(err)
  1415  	}
  1416  	res.Body.Close()
  1417  	delete(res.Header, "Date")
  1418  	delete(res.Header, "Content-Type")
  1419  
  1420  	if want := (Header{"Foo": {"Bar"}}); !reflect.DeepEqual(res.Header, want) {
  1421  		t.Errorf("Header = %#v; want %#v", res.Header, want)
  1422  	}
  1423  	if want := (Header{"Foo": {"Baz", "Baz2"}, "Bar": {"Quux"}}); !reflect.DeepEqual(res.Trailer, want) {
  1424  		t.Errorf("Trailer = %#v; want %#v", res.Trailer, want)
  1425  	}
  1426  }
  1427  
  1428  func TestBadResponseAfterReadingBody(t *testing.T) {
  1429  	run(t, testBadResponseAfterReadingBody, []testMode{http1Mode})
  1430  }
  1431  func testBadResponseAfterReadingBody(t *testing.T, mode testMode) {
  1432  	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
  1433  		_, err := io.Copy(io.Discard, r.Body)
  1434  		if err != nil {
  1435  			t.Fatal(err)
  1436  		}
  1437  		c, _, err := w.(Hijacker).Hijack()
  1438  		if err != nil {
  1439  			t.Fatal(err)
  1440  		}
  1441  		defer c.Close()
  1442  		fmt.Fprintln(c, "some bogus crap")
  1443  	}))
  1444  
  1445  	closes := 0
  1446  	res, err := cst.c.Post(cst.ts.URL, "text/plain", countCloseReader{&closes, strings.NewReader("hello")})
  1447  	if err == nil {
  1448  		res.Body.Close()
  1449  		t.Fatal("expected an error to be returned from Post")
  1450  	}
  1451  	if closes != 1 {
  1452  		t.Errorf("closes = %d; want 1", closes)
  1453  	}
  1454  }
  1455  
  1456  func TestWriteHeader0(t *testing.T) { run(t, testWriteHeader0) }
  1457  func testWriteHeader0(t *testing.T, mode testMode) {
  1458  	gotpanic := make(chan bool, 1)
  1459  	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
  1460  		defer close(gotpanic)
  1461  		defer func() {
  1462  			if e := recover(); e != nil {
  1463  				got := fmt.Sprintf("%T, %v", e, e)
  1464  				want := "string, invalid WriteHeader code 0"
  1465  				if got != want {
  1466  					t.Errorf("unexpected panic value:\n got: %v\nwant: %v\n", got, want)
  1467  				}
  1468  				gotpanic <- true
  1469  
  1470  				// Set an explicit 503. This also tests that the WriteHeader call panics
  1471  				// before it recorded that an explicit value was set and that bogus
  1472  				// value wasn't stuck.
  1473  				w.WriteHeader(503)
  1474  			}
  1475  		}()
  1476  		w.WriteHeader(0)
  1477  	}))
  1478  	res, err := cst.c.Get(cst.ts.URL)
  1479  	if err != nil {
  1480  		t.Fatal(err)
  1481  	}
  1482  	if res.StatusCode != 503 {
  1483  		t.Errorf("Response: %v %q; want 503", res.StatusCode, res.Status)
  1484  	}
  1485  	if !<-gotpanic {
  1486  		t.Error("expected panic in handler")
  1487  	}
  1488  }
  1489  
  1490  // Issue 23010: don't be super strict checking WriteHeader's code if
  1491  // it's not even valid to call WriteHeader then anyway.
  1492  func TestWriteHeaderNoCodeCheck(t *testing.T) {
  1493  	run(t, func(t *testing.T, mode testMode) {
  1494  		testWriteHeaderAfterWrite(t, mode, false)
  1495  	})
  1496  }
  1497  func TestWriteHeaderNoCodeCheck_h1hijack(t *testing.T) {
  1498  	testWriteHeaderAfterWrite(t, http1Mode, true)
  1499  }
  1500  func testWriteHeaderAfterWrite(t *testing.T, mode testMode, hijack bool) {
  1501  	var errorLog lockedBytesBuffer
  1502  	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
  1503  		if hijack {
  1504  			conn, _, _ := w.(Hijacker).Hijack()
  1505  			defer conn.Close()
  1506  			conn.Write([]byte("HTTP/1.1 200 OK\r\nContent-Length: 6\r\n\r\nfoo"))
  1507  			w.WriteHeader(0) // verify this doesn't panic if there's already output; Issue 23010
  1508  			conn.Write([]byte("bar"))
  1509  			return
  1510  		}
  1511  		io.WriteString(w, "foo")
  1512  		w.(Flusher).Flush()
  1513  		w.WriteHeader(0) // verify this doesn't panic if there's already output; Issue 23010
  1514  		io.WriteString(w, "bar")
  1515  	}), func(ts *httptest.Server) {
  1516  		ts.Config.ErrorLog = log.New(&errorLog, "", 0)
  1517  	})
  1518  	res, err := cst.c.Get(cst.ts.URL)
  1519  	if err != nil {
  1520  		t.Fatal(err)
  1521  	}
  1522  	defer res.Body.Close()
  1523  	body, err := io.ReadAll(res.Body)
  1524  	if err != nil {
  1525  		t.Fatal(err)
  1526  	}
  1527  	if got, want := string(body), "foobar"; got != want {
  1528  		t.Errorf("got = %q; want %q", got, want)
  1529  	}
  1530  
  1531  	// Also check the stderr output:
  1532  	if mode == http2Mode {
  1533  		// TODO: also emit this log message for HTTP/2?
  1534  		// We historically haven't, so don't check.
  1535  		return
  1536  	}
  1537  	gotLog := strings.TrimSpace(errorLog.String())
  1538  	wantLog := "http: superfluous response.WriteHeader call from net/http_test.testWriteHeaderAfterWrite.func1 (clientserver_test.go:"
  1539  	if hijack {
  1540  		wantLog = "http: response.WriteHeader on hijacked connection from net/http_test.testWriteHeaderAfterWrite.func1 (clientserver_test.go:"
  1541  	}
  1542  	if !strings.HasPrefix(gotLog, wantLog) {
  1543  		t.Errorf("stderr output = %q; want %q", gotLog, wantLog)
  1544  	}
  1545  }
  1546  
  1547  func TestBidiStreamReverseProxy(t *testing.T) {
  1548  	run(t, testBidiStreamReverseProxy, []testMode{http2Mode})
  1549  }
  1550  func testBidiStreamReverseProxy(t *testing.T, mode testMode) {
  1551  	backend := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
  1552  		if _, err := io.Copy(w, r.Body); err != nil {
  1553  			log.Printf("bidi backend copy: %v", err)
  1554  		}
  1555  	}))
  1556  
  1557  	backURL, err := url.Parse(backend.ts.URL)
  1558  	if err != nil {
  1559  		t.Fatal(err)
  1560  	}
  1561  	rp := httputil.NewSingleHostReverseProxy(backURL)
  1562  	rp.Transport = backend.tr
  1563  	proxy := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
  1564  		rp.ServeHTTP(w, r)
  1565  	}))
  1566  
  1567  	bodyRes := make(chan any, 1) // error or hash.Hash
  1568  	pr, pw := io.Pipe()
  1569  	req, _ := NewRequest("PUT", proxy.ts.URL, pr)
  1570  	const size = 4 << 20
  1571  	go func() {
  1572  		h := sha1.New()
  1573  		_, err := io.CopyN(io.MultiWriter(h, pw), rand.Reader, size)
  1574  		go pw.Close()
  1575  		if err != nil {
  1576  			bodyRes <- err
  1577  		} else {
  1578  			bodyRes <- h
  1579  		}
  1580  	}()
  1581  	res, err := backend.c.Do(req)
  1582  	if err != nil {
  1583  		t.Fatal(err)
  1584  	}
  1585  	defer res.Body.Close()
  1586  	hgot := sha1.New()
  1587  	n, err := io.Copy(hgot, res.Body)
  1588  	if err != nil {
  1589  		t.Fatal(err)
  1590  	}
  1591  	if n != size {
  1592  		t.Fatalf("got %d bytes; want %d", n, size)
  1593  	}
  1594  	select {
  1595  	case v := <-bodyRes:
  1596  		switch v := v.(type) {
  1597  		default:
  1598  			t.Fatalf("body copy: %v", err)
  1599  		case hash.Hash:
  1600  			if !bytes.Equal(v.Sum(nil), hgot.Sum(nil)) {
  1601  				t.Errorf("written bytes didn't match received bytes")
  1602  			}
  1603  		}
  1604  	case <-time.After(10 * time.Second):
  1605  		t.Fatal("timeout")
  1606  	}
  1607  
  1608  }
  1609  
  1610  // Always use HTTP/1.1 for WebSocket upgrades.
  1611  func TestH12_WebSocketUpgrade(t *testing.T) {
  1612  	h12Compare{
  1613  		Handler: func(w ResponseWriter, r *Request) {
  1614  			h := w.Header()
  1615  			h.Set("Foo", "bar")
  1616  		},
  1617  		ReqFunc: func(c *Client, url string) (*Response, error) {
  1618  			req, _ := NewRequest("GET", url, nil)
  1619  			req.Header.Set("Connection", "Upgrade")
  1620  			req.Header.Set("Upgrade", "WebSocket")
  1621  			return c.Do(req)
  1622  		},
  1623  		EarlyCheckResponse: func(proto string, res *Response) {
  1624  			if res.Proto != "HTTP/1.1" {
  1625  				t.Errorf("%s: expected HTTP/1.1, got %q", proto, res.Proto)
  1626  			}
  1627  			res.Proto = "HTTP/IGNORE" // skip later checks that Proto must be 1.1 vs 2.0
  1628  		},
  1629  	}.run(t)
  1630  }
  1631  
  1632  func TestIdentityTransferEncoding(t *testing.T) { run(t, testIdentityTransferEncoding) }
  1633  func testIdentityTransferEncoding(t *testing.T, mode testMode) {
  1634  	const body = "body"
  1635  	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
  1636  		gotBody, _ := io.ReadAll(r.Body)
  1637  		if got, want := string(gotBody), body; got != want {
  1638  			t.Errorf("got request body = %q; want %q", got, want)
  1639  		}
  1640  		w.Header().Set("Transfer-Encoding", "identity")
  1641  		w.WriteHeader(StatusOK)
  1642  		w.(Flusher).Flush()
  1643  		io.WriteString(w, body)
  1644  	}))
  1645  	req, _ := NewRequest("GET", cst.ts.URL, strings.NewReader(body))
  1646  	res, err := cst.c.Do(req)
  1647  	if err != nil {
  1648  		t.Fatal(err)
  1649  	}
  1650  	defer res.Body.Close()
  1651  	gotBody, err := io.ReadAll(res.Body)
  1652  	if err != nil {
  1653  		t.Fatal(err)
  1654  	}
  1655  	if got, want := string(gotBody), body; got != want {
  1656  		t.Errorf("got response body = %q; want %q", got, want)
  1657  	}
  1658  }
  1659  
  1660  func TestEarlyHintsRequest(t *testing.T) { run(t, testEarlyHintsRequest) }
  1661  func testEarlyHintsRequest(t *testing.T, mode testMode) {
  1662  	var wg sync.WaitGroup
  1663  	wg.Add(1)
  1664  	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
  1665  		h := w.Header()
  1666  
  1667  		h.Add("Content-Length", "123") // must be ignored
  1668  		h.Add("Link", "</style.css>; rel=preload; as=style")
  1669  		h.Add("Link", "</script.js>; rel=preload; as=script")
  1670  		w.WriteHeader(StatusEarlyHints)
  1671  
  1672  		wg.Wait()
  1673  
  1674  		h.Add("Link", "</foo.js>; rel=preload; as=script")
  1675  		w.WriteHeader(StatusEarlyHints)
  1676  
  1677  		w.Write([]byte("Hello"))
  1678  	}))
  1679  
  1680  	checkLinkHeaders := func(t *testing.T, expected, got []string) {
  1681  		t.Helper()
  1682  
  1683  		if len(expected) != len(got) {
  1684  			t.Errorf("got %d expected %d", len(got), len(expected))
  1685  		}
  1686  
  1687  		for i := range expected {
  1688  			if expected[i] != got[i] {
  1689  				t.Errorf("got %q expected %q", got[i], expected[i])
  1690  			}
  1691  		}
  1692  	}
  1693  
  1694  	checkExcludedHeaders := func(t *testing.T, header textproto.MIMEHeader) {
  1695  		t.Helper()
  1696  
  1697  		for _, h := range []string{"Content-Length", "Transfer-Encoding"} {
  1698  			if v, ok := header[h]; ok {
  1699  				t.Errorf("%s is %q; must not be sent", h, v)
  1700  			}
  1701  		}
  1702  	}
  1703  
  1704  	var respCounter uint8
  1705  	trace := &httptrace.ClientTrace{
  1706  		Got1xxResponse: func(code int, header textproto.MIMEHeader) error {
  1707  			switch respCounter {
  1708  			case 0:
  1709  				checkLinkHeaders(t, []string{"</style.css>; rel=preload; as=style", "</script.js>; rel=preload; as=script"}, header["Link"])
  1710  				checkExcludedHeaders(t, header)
  1711  
  1712  				wg.Done()
  1713  			case 1:
  1714  				checkLinkHeaders(t, []string{"</style.css>; rel=preload; as=style", "</script.js>; rel=preload; as=script", "</foo.js>; rel=preload; as=script"}, header["Link"])
  1715  				checkExcludedHeaders(t, header)
  1716  
  1717  			default:
  1718  				t.Error("Unexpected 1xx response")
  1719  			}
  1720  
  1721  			respCounter++
  1722  
  1723  			return nil
  1724  		},
  1725  	}
  1726  	req, _ := NewRequestWithContext(httptrace.WithClientTrace(context.Background(), trace), "GET", cst.ts.URL, nil)
  1727  
  1728  	res, err := cst.c.Do(req)
  1729  	if err != nil {
  1730  		t.Fatal(err)
  1731  	}
  1732  	defer res.Body.Close()
  1733  
  1734  	checkLinkHeaders(t, []string{"</style.css>; rel=preload; as=style", "</script.js>; rel=preload; as=script", "</foo.js>; rel=preload; as=script"}, res.Header["Link"])
  1735  	if cl := res.Header.Get("Content-Length"); cl != "123" {
  1736  		t.Errorf("Content-Length is %q; want 123", cl)
  1737  	}
  1738  
  1739  	body, _ := io.ReadAll(res.Body)
  1740  	if string(body) != "Hello" {
  1741  		t.Errorf("Read body %q; want Hello", body)
  1742  	}
  1743  }