github.com/ooni/oohttp@v0.7.2/clientserver_test.go (about)

     1  // Copyright 2015 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  // Tests that use both the client & server, in both HTTP/1 and HTTP/2 mode.
     6  
     7  package http_test
     8  
     9  import (
    10  	"bytes"
    11  	"compress/gzip"
    12  	"context"
    13  	"crypto/rand"
    14  	"crypto/sha1"
    15  	"crypto/tls"
    16  	"fmt"
    17  	"hash"
    18  	"io"
    19  	"log"
    20  	"net"
    21  	"net/textproto"
    22  	"net/url"
    23  	"os"
    24  	"reflect"
    25  	"runtime"
    26  	"sort"
    27  	"strings"
    28  	"sync"
    29  	"sync/atomic"
    30  	"testing"
    31  	"time"
    32  
    33  	. "github.com/ooni/oohttp"
    34  	httptest "github.com/ooni/oohttp/httptest"
    35  	httptrace "github.com/ooni/oohttp/httptrace"
    36  	httputil "github.com/ooni/oohttp/httputil"
    37  )
    38  
    39  type testMode string
    40  
    41  const (
    42  	http1Mode  = testMode("h1")     // HTTP/1.1
    43  	https1Mode = testMode("https1") // HTTPS/1.1
    44  	http2Mode  = testMode("h2")     // HTTP/2
    45  )
    46  
    47  type testNotParallelOpt struct{}
    48  
    49  var (
    50  	testNotParallel = testNotParallelOpt{}
    51  )
    52  
    53  type TBRun[T any] interface {
    54  	testing.TB
    55  	Run(string, func(T)) bool
    56  }
    57  
    58  // run runs a client/server test in a variety of test configurations.
    59  //
    60  // Tests execute in HTTP/1.1 and HTTP/2 modes by default.
    61  // To run in a different set of configurations, pass a []testMode option.
    62  //
    63  // Tests call t.Parallel() by default.
    64  // To disable parallel execution, pass the testNotParallel option.
    65  func run[T TBRun[T]](t T, f func(t T, mode testMode), opts ...any) {
    66  	t.Helper()
    67  	modes := []testMode{http1Mode, http2Mode}
    68  	parallel := true
    69  	for _, opt := range opts {
    70  		switch opt := opt.(type) {
    71  		case []testMode:
    72  			modes = opt
    73  		case testNotParallelOpt:
    74  			parallel = false
    75  		default:
    76  			t.Fatalf("unknown option type %T", opt)
    77  		}
    78  	}
    79  	if t, ok := any(t).(*testing.T); ok && parallel {
    80  		setParallel(t)
    81  	}
    82  	for _, mode := range modes {
    83  		t.Run(string(mode), func(t T) {
    84  			t.Helper()
    85  			if t, ok := any(t).(*testing.T); ok && parallel {
    86  				setParallel(t)
    87  			}
    88  			t.Cleanup(func() {
    89  				afterTest(t)
    90  			})
    91  			f(t, mode)
    92  		})
    93  	}
    94  }
    95  
// clientServerTest bundles a test server with a client configured to talk to
// it. Instances are built by newClientServerTest.
type clientServerTest struct {
	t  testing.TB       // test that owns this setup
	h2 bool             // true when created in http2Mode
	h  Handler          // handler served by ts
	ts *httptest.Server // the server under test
	tr *Transport       // the transport of c
	c  *Client          // client obtained from ts.Client()
}
   104  
   105  func (t *clientServerTest) close() {
   106  	t.tr.CloseIdleConnections()
   107  	t.ts.Close()
   108  }
   109  
   110  func (t *clientServerTest) getURL(u string) string {
   111  	res, err := t.c.Get(u)
   112  	if err != nil {
   113  		t.t.Fatal(err)
   114  	}
   115  	defer res.Body.Close()
   116  	slurp, err := io.ReadAll(res.Body)
   117  	if err != nil {
   118  		t.t.Fatal(err)
   119  	}
   120  	return string(slurp)
   121  }
   122  
   123  func (t *clientServerTest) scheme() string {
   124  	if t.h2 {
   125  		return "https"
   126  	}
   127  	return "http"
   128  }
   129  
   130  var optQuietLog = func(ts *httptest.Server) {
   131  	ts.Config.ErrorLog = quietLog
   132  }
   133  
   134  func optWithServerLog(lg *log.Logger) func(*httptest.Server) {
   135  	return func(ts *httptest.Server) {
   136  		ts.Config.ErrorLog = lg
   137  	}
   138  }
   139  
   140  // newClientServerTest creates and starts an httptest.Server.
   141  //
   142  // The mode parameter selects the implementation to test:
   143  // HTTP/1, HTTP/2, etc. Tests using newClientServerTest should use
   144  // the 'run' function, which will start a subtests for each tested mode.
   145  //
   146  // The vararg opts parameter can include functions to configure the
   147  // test server or transport.
   148  //
   149  //	func(*httptest.Server) // run before starting the server
   150  //	func(*http.Transport)
   151  func newClientServerTest(t testing.TB, mode testMode, h Handler, opts ...any) *clientServerTest {
   152  	if mode == http2Mode {
   153  		CondSkipHTTP2(t)
   154  	}
   155  	cst := &clientServerTest{
   156  		t:  t,
   157  		h2: mode == http2Mode,
   158  		h:  h,
   159  	}
   160  	cst.ts = httptest.NewUnstartedServer(h)
   161  
   162  	var transportFuncs []func(*Transport)
   163  	for _, opt := range opts {
   164  		switch opt := opt.(type) {
   165  		case func(*Transport):
   166  			transportFuncs = append(transportFuncs, opt)
   167  		case func(*httptest.Server):
   168  			opt(cst.ts)
   169  		default:
   170  			t.Fatalf("unhandled option type %T", opt)
   171  		}
   172  	}
   173  
   174  	if cst.ts.Config.ErrorLog == nil {
   175  		cst.ts.Config.ErrorLog = log.New(testLogWriter{t}, "", 0)
   176  	}
   177  
   178  	switch mode {
   179  	case http1Mode:
   180  		cst.ts.Start()
   181  	case https1Mode:
   182  		cst.ts.StartTLS()
   183  	case http2Mode:
   184  		ExportHttp2ConfigureServer(cst.ts.Config, nil)
   185  		cst.ts.TLS = cst.ts.Config.TLSConfig
   186  		cst.ts.StartTLS()
   187  	default:
   188  		t.Fatalf("unknown test mode %v", mode)
   189  	}
   190  	cst.c = cst.ts.Client()
   191  	cst.tr = cst.c.Transport.(*Transport)
   192  	if mode == http2Mode {
   193  		if err := ExportHttp2ConfigureTransport(cst.tr); err != nil {
   194  			t.Fatal(err)
   195  		}
   196  	}
   197  	for _, f := range transportFuncs {
   198  		f(cst.tr)
   199  	}
   200  	t.Cleanup(func() {
   201  		cst.close()
   202  	})
   203  	return cst
   204  }
   205  
   206  type testLogWriter struct {
   207  	t testing.TB
   208  }
   209  
   210  func (w testLogWriter) Write(b []byte) (int, error) {
   211  	w.t.Logf("server log: %v", strings.TrimSpace(string(b)))
   212  	return len(b), nil
   213  }
   214  
   215  // Testing the newClientServerTest helper itself.
   216  func TestNewClientServerTest(t *testing.T) {
   217  	run(t, testNewClientServerTest, []testMode{http1Mode, https1Mode, http2Mode})
   218  }
   219  func testNewClientServerTest(t *testing.T, mode testMode) {
   220  	var got struct {
   221  		sync.Mutex
   222  		proto  string
   223  		hasTLS bool
   224  	}
   225  	h := HandlerFunc(func(w ResponseWriter, r *Request) {
   226  		got.Lock()
   227  		defer got.Unlock()
   228  		got.proto = r.Proto
   229  		got.hasTLS = r.TLS != nil
   230  	})
   231  	cst := newClientServerTest(t, mode, h)
   232  	if _, err := cst.c.Head(cst.ts.URL); err != nil {
   233  		t.Fatal(err)
   234  	}
   235  	var wantProto string
   236  	var wantTLS bool
   237  	switch mode {
   238  	case http1Mode:
   239  		wantProto = "HTTP/1.1"
   240  		wantTLS = false
   241  	case https1Mode:
   242  		wantProto = "HTTP/1.1"
   243  		wantTLS = true
   244  	case http2Mode:
   245  		wantProto = "HTTP/2.0"
   246  		wantTLS = true
   247  	}
   248  	if got.proto != wantProto {
   249  		t.Errorf("req.Proto = %q, want %q", got.proto, wantProto)
   250  	}
   251  	if got.hasTLS != wantTLS {
   252  		t.Errorf("req.TLS set: %v, want %v", got.hasTLS, wantTLS)
   253  	}
   254  }
   255  
   256  func TestChunkedResponseHeaders(t *testing.T) { run(t, testChunkedResponseHeaders) }
   257  func testChunkedResponseHeaders(t *testing.T, mode testMode) {
   258  	log.SetOutput(io.Discard) // is noisy otherwise
   259  	defer log.SetOutput(os.Stderr)
   260  	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
   261  		w.Header().Set("Content-Length", "intentional gibberish") // we check that this is deleted
   262  		w.(Flusher).Flush()
   263  		fmt.Fprintf(w, "I am a chunked response.")
   264  	}))
   265  
   266  	res, err := cst.c.Get(cst.ts.URL)
   267  	if err != nil {
   268  		t.Fatalf("Get error: %v", err)
   269  	}
   270  	defer res.Body.Close()
   271  	if g, e := res.ContentLength, int64(-1); g != e {
   272  		t.Errorf("expected ContentLength of %d; got %d", e, g)
   273  	}
   274  	wantTE := []string{"chunked"}
   275  	if mode == http2Mode {
   276  		wantTE = nil
   277  	}
   278  	if !reflect.DeepEqual(res.TransferEncoding, wantTE) {
   279  		t.Errorf("TransferEncoding = %v; want %v", res.TransferEncoding, wantTE)
   280  	}
   281  	if got, haveCL := res.Header["Content-Length"]; haveCL {
   282  		t.Errorf("Unexpected Content-Length: %q", got)
   283  	}
   284  }
   285  
   286  type reqFunc func(c *Client, url string) (*Response, error)
   287  
// h12Compare is a test that compares HTTP/1 and HTTP/2 behavior
// against each other.
//
// Its run method serves Handler from one HTTP/1 and one HTTP/2 test server,
// issues the same request to each, normalizes both responses and reports
// any difference between them.
type h12Compare struct {
	Handler            func(ResponseWriter, *Request)    // required
	ReqFunc            reqFunc                           // optional; defaults to (*Client).Get
	CheckResponse      func(proto string, res *Response) // optional; called after normalization
	EarlyCheckResponse func(proto string, res *Response) // optional; pre-normalize
	Opts               []any                             // options passed to newClientServerTest for both servers
}
   297  
   298  func (tt h12Compare) reqFunc() reqFunc {
   299  	if tt.ReqFunc == nil {
   300  		return (*Client).Get
   301  	}
   302  	return tt.ReqFunc
   303  }
   304  
   305  func (tt h12Compare) run(t *testing.T) {
   306  	setParallel(t)
   307  	cst1 := newClientServerTest(t, http1Mode, HandlerFunc(tt.Handler), tt.Opts...)
   308  	defer cst1.close()
   309  	cst2 := newClientServerTest(t, http2Mode, HandlerFunc(tt.Handler), tt.Opts...)
   310  	defer cst2.close()
   311  
   312  	res1, err := tt.reqFunc()(cst1.c, cst1.ts.URL)
   313  	if err != nil {
   314  		t.Errorf("HTTP/1 request: %v", err)
   315  		return
   316  	}
   317  	res2, err := tt.reqFunc()(cst2.c, cst2.ts.URL)
   318  	if err != nil {
   319  		t.Errorf("HTTP/2 request: %v", err)
   320  		return
   321  	}
   322  
   323  	if fn := tt.EarlyCheckResponse; fn != nil {
   324  		fn("HTTP/1.1", res1)
   325  		fn("HTTP/2.0", res2)
   326  	}
   327  
   328  	tt.normalizeRes(t, res1, "HTTP/1.1")
   329  	tt.normalizeRes(t, res2, "HTTP/2.0")
   330  	res1body, res2body := res1.Body, res2.Body
   331  
   332  	eres1 := mostlyCopy(res1)
   333  	eres2 := mostlyCopy(res2)
   334  	if !reflect.DeepEqual(eres1, eres2) {
   335  		t.Errorf("Response headers to handler differed:\nhttp/1 (%v):\n\t%#v\nhttp/2 (%v):\n\t%#v",
   336  			cst1.ts.URL, eres1, cst2.ts.URL, eres2)
   337  	}
   338  	if !reflect.DeepEqual(res1body, res2body) {
   339  		t.Errorf("Response bodies to handler differed.\nhttp1: %v\nhttp2: %v\n", res1body, res2body)
   340  	}
   341  	if fn := tt.CheckResponse; fn != nil {
   342  		res1.Body, res2.Body = res1body, res2body
   343  		fn("HTTP/1.1", res1)
   344  		fn("HTTP/2.0", res2)
   345  	}
   346  }
   347  
   348  func mostlyCopy(r *Response) *Response {
   349  	c := *r
   350  	c.Body = nil
   351  	c.TransferEncoding = nil
   352  	c.TLS = nil
   353  	c.Request = nil
   354  	return &c
   355  }
   356  
   357  type slurpResult struct {
   358  	io.ReadCloser
   359  	body []byte
   360  	err  error
   361  }
   362  
   363  func (sr slurpResult) String() string { return fmt.Sprintf("body %q; err %v", sr.body, sr.err) }
   364  
   365  func (tt h12Compare) normalizeRes(t *testing.T, res *Response, wantProto string) {
   366  	if res.Proto == wantProto || res.Proto == "HTTP/IGNORE" {
   367  		res.Proto, res.ProtoMajor, res.ProtoMinor = "", 0, 0
   368  	} else {
   369  		t.Errorf("got %q response; want %q", res.Proto, wantProto)
   370  	}
   371  	slurp, err := io.ReadAll(res.Body)
   372  
   373  	res.Body.Close()
   374  	res.Body = slurpResult{
   375  		ReadCloser: io.NopCloser(bytes.NewReader(slurp)),
   376  		body:       slurp,
   377  		err:        err,
   378  	}
   379  	for i, v := range res.Header["Date"] {
   380  		res.Header["Date"][i] = strings.Repeat("x", len(v))
   381  	}
   382  	if res.Request == nil {
   383  		t.Errorf("for %s, no request", wantProto)
   384  	}
   385  	if (res.TLS != nil) != (wantProto == "HTTP/2.0") {
   386  		t.Errorf("TLS set = %v; want %v", res.TLS != nil, res.TLS == nil)
   387  	}
   388  }
   389  
   390  // Issue 13532
   391  func TestH12_HeadContentLengthNoBody(t *testing.T) {
   392  	h12Compare{
   393  		ReqFunc: (*Client).Head,
   394  		Handler: func(w ResponseWriter, r *Request) {
   395  		},
   396  	}.run(t)
   397  }
   398  
   399  func TestH12_HeadContentLengthSmallBody(t *testing.T) {
   400  	h12Compare{
   401  		ReqFunc: (*Client).Head,
   402  		Handler: func(w ResponseWriter, r *Request) {
   403  			io.WriteString(w, "small")
   404  		},
   405  	}.run(t)
   406  }
   407  
   408  func TestH12_HeadContentLengthLargeBody(t *testing.T) {
   409  	h12Compare{
   410  		ReqFunc: (*Client).Head,
   411  		Handler: func(w ResponseWriter, r *Request) {
   412  			chunk := strings.Repeat("x", 512<<10)
   413  			for i := 0; i < 10; i++ {
   414  				io.WriteString(w, chunk)
   415  			}
   416  		},
   417  	}.run(t)
   418  }
   419  
   420  func TestH12_200NoBody(t *testing.T) {
   421  	h12Compare{Handler: func(w ResponseWriter, r *Request) {}}.run(t)
   422  }
   423  
   424  func TestH2_204NoBody(t *testing.T) { testH12_noBody(t, 204) }
   425  func TestH2_304NoBody(t *testing.T) { testH12_noBody(t, 304) }
   426  func TestH2_404NoBody(t *testing.T) { testH12_noBody(t, 404) }
   427  
   428  func testH12_noBody(t *testing.T, status int) {
   429  	h12Compare{Handler: func(w ResponseWriter, r *Request) {
   430  		w.WriteHeader(status)
   431  	}}.run(t)
   432  }
   433  
   434  func TestH12_SmallBody(t *testing.T) {
   435  	h12Compare{Handler: func(w ResponseWriter, r *Request) {
   436  		io.WriteString(w, "small body")
   437  	}}.run(t)
   438  }
   439  
   440  func TestH12_ExplicitContentLength(t *testing.T) {
   441  	h12Compare{Handler: func(w ResponseWriter, r *Request) {
   442  		w.Header().Set("Content-Length", "3")
   443  		io.WriteString(w, "foo")
   444  	}}.run(t)
   445  }
   446  
   447  func TestH12_FlushBeforeBody(t *testing.T) {
   448  	h12Compare{Handler: func(w ResponseWriter, r *Request) {
   449  		w.(Flusher).Flush()
   450  		io.WriteString(w, "foo")
   451  	}}.run(t)
   452  }
   453  
   454  func TestH12_FlushMidBody(t *testing.T) {
   455  	h12Compare{Handler: func(w ResponseWriter, r *Request) {
   456  		io.WriteString(w, "foo")
   457  		w.(Flusher).Flush()
   458  		io.WriteString(w, "bar")
   459  	}}.run(t)
   460  }
   461  
   462  func TestH12_Head_ExplicitLen(t *testing.T) {
   463  	h12Compare{
   464  		ReqFunc: (*Client).Head,
   465  		Handler: func(w ResponseWriter, r *Request) {
   466  			if r.Method != "HEAD" {
   467  				t.Errorf("unexpected method %q", r.Method)
   468  			}
   469  			w.Header().Set("Content-Length", "1235")
   470  		},
   471  	}.run(t)
   472  }
   473  
   474  func TestH12_Head_ImplicitLen(t *testing.T) {
   475  	h12Compare{
   476  		ReqFunc: (*Client).Head,
   477  		Handler: func(w ResponseWriter, r *Request) {
   478  			if r.Method != "HEAD" {
   479  				t.Errorf("unexpected method %q", r.Method)
   480  			}
   481  			io.WriteString(w, "foo")
   482  		},
   483  	}.run(t)
   484  }
   485  
   486  func TestH12_HandlerWritesTooLittle(t *testing.T) {
   487  	h12Compare{
   488  		Handler: func(w ResponseWriter, r *Request) {
   489  			w.Header().Set("Content-Length", "3")
   490  			io.WriteString(w, "12") // one byte short
   491  		},
   492  		CheckResponse: func(proto string, res *Response) {
   493  			sr, ok := res.Body.(slurpResult)
   494  			if !ok {
   495  				t.Errorf("%s body is %T; want slurpResult", proto, res.Body)
   496  				return
   497  			}
   498  			if sr.err != io.ErrUnexpectedEOF {
   499  				t.Errorf("%s read error = %v; want io.ErrUnexpectedEOF", proto, sr.err)
   500  			}
   501  			if string(sr.body) != "12" {
   502  				t.Errorf("%s body = %q; want %q", proto, sr.body, "12")
   503  			}
   504  		},
   505  	}.run(t)
   506  }
   507  
   508  // Tests that the HTTP/1 and HTTP/2 servers prevent handlers from
   509  // writing more than they declared. This test does not test whether
   510  // the transport deals with too much data, though, since the server
   511  // doesn't make it possible to send bogus data. For those tests, see
   512  // transport_test.go (for HTTP/1) or x/net/http2/transport_test.go
   513  // (for HTTP/2).
   514  func TestH12_HandlerWritesTooMuch(t *testing.T) {
   515  	h12Compare{
   516  		Handler: func(w ResponseWriter, r *Request) {
   517  			w.Header().Set("Content-Length", "3")
   518  			w.(Flusher).Flush()
   519  			io.WriteString(w, "123")
   520  			w.(Flusher).Flush()
   521  			n, err := io.WriteString(w, "x") // too many
   522  			if n > 0 || err == nil {
   523  				t.Errorf("for proto %q, final write = %v, %v; want 0, some error", r.Proto, n, err)
   524  			}
   525  		},
   526  	}.run(t)
   527  }
   528  
   529  // Verify that both our HTTP/1 and HTTP/2 request and auto-decompress gzip.
   530  // Some hosts send gzip even if you don't ask for it; see golang.org/issue/13298
   531  func TestH12_AutoGzip(t *testing.T) {
   532  	h12Compare{
   533  		Handler: func(w ResponseWriter, r *Request) {
   534  			if ae := r.Header.Get("Accept-Encoding"); ae != "gzip" {
   535  				t.Errorf("%s Accept-Encoding = %q; want gzip", r.Proto, ae)
   536  			}
   537  			w.Header().Set("Content-Encoding", "gzip")
   538  			gz := gzip.NewWriter(w)
   539  			io.WriteString(gz, "I am some gzipped content. Go go go go go go go go go go go go should compress well.")
   540  			gz.Close()
   541  		},
   542  	}.run(t)
   543  }
   544  
   545  func TestH12_AutoGzip_Disabled(t *testing.T) {
   546  	h12Compare{
   547  		Opts: []any{
   548  			func(tr *Transport) { tr.DisableCompression = true },
   549  		},
   550  		Handler: func(w ResponseWriter, r *Request) {
   551  			fmt.Fprintf(w, "%q", r.Header["Accept-Encoding"])
   552  			if ae := r.Header.Get("Accept-Encoding"); ae != "" {
   553  				t.Errorf("%s Accept-Encoding = %q; want empty", r.Proto, ae)
   554  			}
   555  		},
   556  	}.run(t)
   557  }
   558  
   559  // Test304Responses verifies that 304s don't declare that they're
   560  // chunking in their response headers and aren't allowed to produce
   561  // output.
   562  func Test304Responses(t *testing.T) { run(t, test304Responses) }
   563  func test304Responses(t *testing.T, mode testMode) {
   564  	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
   565  		w.WriteHeader(StatusNotModified)
   566  		_, err := w.Write([]byte("illegal body"))
   567  		if err != ErrBodyNotAllowed {
   568  			t.Errorf("on Write, expected ErrBodyNotAllowed, got %v", err)
   569  		}
   570  	}))
   571  	defer cst.close()
   572  	res, err := cst.c.Get(cst.ts.URL)
   573  	if err != nil {
   574  		t.Fatal(err)
   575  	}
   576  	if len(res.TransferEncoding) > 0 {
   577  		t.Errorf("expected no TransferEncoding; got %v", res.TransferEncoding)
   578  	}
   579  	body, err := io.ReadAll(res.Body)
   580  	if err != nil {
   581  		t.Error(err)
   582  	}
   583  	if len(body) > 0 {
   584  		t.Errorf("got unexpected body %q", string(body))
   585  	}
   586  }
   587  
   588  func TestH12_ServerEmptyContentLength(t *testing.T) {
   589  	h12Compare{
   590  		Handler: func(w ResponseWriter, r *Request) {
   591  			w.Header()["Content-Type"] = []string{""}
   592  			io.WriteString(w, "<html><body>hi</body></html>")
   593  		},
   594  	}.run(t)
   595  }
   596  
   597  func TestH12_RequestContentLength_Known_NonZero(t *testing.T) {
   598  	h12requestContentLength(t, func() io.Reader { return strings.NewReader("FOUR") }, 4)
   599  }
   600  
   601  func TestH12_RequestContentLength_Known_Zero(t *testing.T) {
   602  	h12requestContentLength(t, func() io.Reader { return nil }, 0)
   603  }
   604  
   605  func TestH12_RequestContentLength_Unknown(t *testing.T) {
   606  	h12requestContentLength(t, func() io.Reader { return struct{ io.Reader }{strings.NewReader("Stuff")} }, -1)
   607  }
   608  
   609  func h12requestContentLength(t *testing.T, bodyfn func() io.Reader, wantLen int64) {
   610  	h12Compare{
   611  		Handler: func(w ResponseWriter, r *Request) {
   612  			w.Header().Set("Got-Length", fmt.Sprint(r.ContentLength))
   613  			fmt.Fprintf(w, "Req.ContentLength=%v", r.ContentLength)
   614  		},
   615  		ReqFunc: func(c *Client, url string) (*Response, error) {
   616  			return c.Post(url, "text/plain", bodyfn())
   617  		},
   618  		CheckResponse: func(proto string, res *Response) {
   619  			if got, want := res.Header.Get("Got-Length"), fmt.Sprint(wantLen); got != want {
   620  				t.Errorf("Proto %q got length %q; want %q", proto, got, want)
   621  			}
   622  		},
   623  	}.run(t)
   624  }
   625  
   626  // Tests that closing the Request.Cancel channel also while still
   627  // reading the response body. Issue 13159.
   628  func TestCancelRequestMidBody(t *testing.T) { run(t, testCancelRequestMidBody) }
   629  func testCancelRequestMidBody(t *testing.T, mode testMode) {
   630  	unblock := make(chan bool)
   631  	didFlush := make(chan bool, 1)
   632  	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
   633  		io.WriteString(w, "Hello")
   634  		w.(Flusher).Flush()
   635  		didFlush <- true
   636  		<-unblock
   637  		io.WriteString(w, ", world.")
   638  	}))
   639  	defer close(unblock)
   640  
   641  	req, _ := NewRequest("GET", cst.ts.URL, nil)
   642  	cancel := make(chan struct{})
   643  	req.Cancel = cancel
   644  
   645  	res, err := cst.c.Do(req)
   646  	if err != nil {
   647  		t.Fatal(err)
   648  	}
   649  	defer res.Body.Close()
   650  	<-didFlush
   651  
   652  	// Read a bit before we cancel. (Issue 13626)
   653  	// We should have "Hello" at least sitting there.
   654  	firstRead := make([]byte, 10)
   655  	n, err := res.Body.Read(firstRead)
   656  	if err != nil {
   657  		t.Fatal(err)
   658  	}
   659  	firstRead = firstRead[:n]
   660  
   661  	close(cancel)
   662  
   663  	rest, err := io.ReadAll(res.Body)
   664  	all := string(firstRead) + string(rest)
   665  	if all != "Hello" {
   666  		t.Errorf("Read %q (%q + %q); want Hello", all, firstRead, rest)
   667  	}
   668  	if err != ExportErrRequestCanceled {
   669  		t.Errorf("ReadAll error = %v; want %v", err, ExportErrRequestCanceled)
   670  	}
   671  }
   672  
   673  // Tests that clients can send trailers to a server and that the server can read them.
   674  func TestTrailersClientToServer(t *testing.T) { run(t, testTrailersClientToServer) }
   675  func testTrailersClientToServer(t *testing.T, mode testMode) {
   676  	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
   677  		var decl []string
   678  		for k := range r.Trailer {
   679  			decl = append(decl, k)
   680  		}
   681  		sort.Strings(decl)
   682  
   683  		slurp, err := io.ReadAll(r.Body)
   684  		if err != nil {
   685  			t.Errorf("Server reading request body: %v", err)
   686  		}
   687  		if string(slurp) != "foo" {
   688  			t.Errorf("Server read request body %q; want foo", slurp)
   689  		}
   690  		if r.Trailer == nil {
   691  			io.WriteString(w, "nil Trailer")
   692  		} else {
   693  			fmt.Fprintf(w, "decl: %v, vals: %s, %s",
   694  				decl,
   695  				r.Trailer.Get("Client-Trailer-A"),
   696  				r.Trailer.Get("Client-Trailer-B"))
   697  		}
   698  	}))
   699  
   700  	var req *Request
   701  	req, _ = NewRequest("POST", cst.ts.URL, io.MultiReader(
   702  		eofReaderFunc(func() {
   703  			req.Trailer["Client-Trailer-A"] = []string{"valuea"}
   704  		}),
   705  		strings.NewReader("foo"),
   706  		eofReaderFunc(func() {
   707  			req.Trailer["Client-Trailer-B"] = []string{"valueb"}
   708  		}),
   709  	))
   710  	req.Trailer = Header{
   711  		"Client-Trailer-A": nil, //  to be set later
   712  		"Client-Trailer-B": nil, //  to be set later
   713  	}
   714  	req.ContentLength = -1
   715  	res, err := cst.c.Do(req)
   716  	if err != nil {
   717  		t.Fatal(err)
   718  	}
   719  	if err := wantBody(res, err, "decl: [Client-Trailer-A Client-Trailer-B], vals: valuea, valueb"); err != nil {
   720  		t.Error(err)
   721  	}
   722  }
   723  
   724  // Tests that servers send trailers to a client and that the client can read them.
   725  func TestTrailersServerToClient(t *testing.T) {
   726  	run(t, func(t *testing.T, mode testMode) {
   727  		testTrailersServerToClient(t, mode, false)
   728  	})
   729  }
   730  func TestTrailersServerToClientFlush(t *testing.T) {
   731  	run(t, func(t *testing.T, mode testMode) {
   732  		testTrailersServerToClient(t, mode, true)
   733  	})
   734  }
   735  
   736  func testTrailersServerToClient(t *testing.T, mode testMode, flush bool) {
   737  	const body = "Some body"
   738  	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
   739  		w.Header().Set("Trailer", "Server-Trailer-A, Server-Trailer-B")
   740  		w.Header().Add("Trailer", "Server-Trailer-C")
   741  
   742  		io.WriteString(w, body)
   743  		if flush {
   744  			w.(Flusher).Flush()
   745  		}
   746  
   747  		// How handlers set Trailers: declare it ahead of time
   748  		// with the Trailer header, and then mutate the
   749  		// Header() of those values later, after the response
   750  		// has been written (we wrote to w above).
   751  		w.Header().Set("Server-Trailer-A", "valuea")
   752  		w.Header().Set("Server-Trailer-C", "valuec") // skipping B
   753  		w.Header().Set("Server-Trailer-NotDeclared", "should be omitted")
   754  	}))
   755  
   756  	res, err := cst.c.Get(cst.ts.URL)
   757  	if err != nil {
   758  		t.Fatal(err)
   759  	}
   760  
   761  	wantHeader := Header{
   762  		"Content-Type": {"text/plain; charset=utf-8"},
   763  	}
   764  	wantLen := -1
   765  	if mode == http2Mode && !flush {
   766  		// In HTTP/1.1, any use of trailers forces HTTP/1.1
   767  		// chunking and a flush at the first write. That's
   768  		// unnecessary with HTTP/2's framing, so the server
   769  		// is able to calculate the length while still sending
   770  		// trailers afterwards.
   771  		wantLen = len(body)
   772  		wantHeader["Content-Length"] = []string{fmt.Sprint(wantLen)}
   773  	}
   774  	if res.ContentLength != int64(wantLen) {
   775  		t.Errorf("ContentLength = %v; want %v", res.ContentLength, wantLen)
   776  	}
   777  
   778  	delete(res.Header, "Date") // irrelevant for test
   779  	if !reflect.DeepEqual(res.Header, wantHeader) {
   780  		t.Errorf("Header = %v; want %v", res.Header, wantHeader)
   781  	}
   782  
   783  	if got, want := res.Trailer, (Header{
   784  		"Server-Trailer-A": nil,
   785  		"Server-Trailer-B": nil,
   786  		"Server-Trailer-C": nil,
   787  	}); !reflect.DeepEqual(got, want) {
   788  		t.Errorf("Trailer before body read = %v; want %v", got, want)
   789  	}
   790  
   791  	if err := wantBody(res, nil, body); err != nil {
   792  		t.Fatal(err)
   793  	}
   794  
   795  	if got, want := res.Trailer, (Header{
   796  		"Server-Trailer-A": {"valuea"},
   797  		"Server-Trailer-B": nil,
   798  		"Server-Trailer-C": {"valuec"},
   799  	}); !reflect.DeepEqual(got, want) {
   800  		t.Errorf("Trailer after body read = %v; want %v", got, want)
   801  	}
   802  }
   803  
   804  // Don't allow a Body.Read after Body.Close. Issue 13648.
   805  func TestResponseBodyReadAfterClose(t *testing.T) { run(t, testResponseBodyReadAfterClose) }
   806  func testResponseBodyReadAfterClose(t *testing.T, mode testMode) {
   807  	const body = "Some body"
   808  	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
   809  		io.WriteString(w, body)
   810  	}))
   811  	res, err := cst.c.Get(cst.ts.URL)
   812  	if err != nil {
   813  		t.Fatal(err)
   814  	}
   815  	res.Body.Close()
   816  	data, err := io.ReadAll(res.Body)
   817  	if len(data) != 0 || err == nil {
   818  		t.Fatalf("ReadAll returned %q, %v; want error", data, err)
   819  	}
   820  }
   821  
   822  func TestConcurrentReadWriteReqBody(t *testing.T) { run(t, testConcurrentReadWriteReqBody) }
   823  func testConcurrentReadWriteReqBody(t *testing.T, mode testMode) {
   824  	const reqBody = "some request body"
   825  	const resBody = "some response body"
   826  	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
   827  		var wg sync.WaitGroup
   828  		wg.Add(2)
   829  		didRead := make(chan bool, 1)
   830  		// Read in one goroutine.
   831  		go func() {
   832  			defer wg.Done()
   833  			data, err := io.ReadAll(r.Body)
   834  			if string(data) != reqBody {
   835  				t.Errorf("Handler read %q; want %q", data, reqBody)
   836  			}
   837  			if err != nil {
   838  				t.Errorf("Handler Read: %v", err)
   839  			}
   840  			didRead <- true
   841  		}()
   842  		// Write in another goroutine.
   843  		go func() {
   844  			defer wg.Done()
   845  			if mode != http2Mode {
   846  				// our HTTP/1 implementation intentionally
   847  				// doesn't permit writes during read (mostly
   848  				// due to it being undefined); if that is ever
   849  				// relaxed, change this.
   850  				<-didRead
   851  			}
   852  			io.WriteString(w, resBody)
   853  		}()
   854  		wg.Wait()
   855  	}))
   856  	req, _ := NewRequest("POST", cst.ts.URL, strings.NewReader(reqBody))
   857  	req.Header.Add("Expect", "100-continue") // just to complicate things
   858  	res, err := cst.c.Do(req)
   859  	if err != nil {
   860  		t.Fatal(err)
   861  	}
   862  	data, err := io.ReadAll(res.Body)
   863  	defer res.Body.Close()
   864  	if err != nil {
   865  		t.Fatal(err)
   866  	}
   867  	if string(data) != resBody {
   868  		t.Errorf("read %q; want %q", data, resBody)
   869  	}
   870  }
   871  
   872  func TestConnectRequest(t *testing.T) { run(t, testConnectRequest) }
   873  func testConnectRequest(t *testing.T, mode testMode) {
   874  	gotc := make(chan *Request, 1)
   875  	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
   876  		gotc <- r
   877  	}))
   878  
   879  	u, err := url.Parse(cst.ts.URL)
   880  	if err != nil {
   881  		t.Fatal(err)
   882  	}
   883  
   884  	tests := []struct {
   885  		req  *Request
   886  		want string
   887  	}{
   888  		{
   889  			req: &Request{
   890  				Method: "CONNECT",
   891  				Header: Header{},
   892  				URL:    u,
   893  			},
   894  			want: u.Host,
   895  		},
   896  		{
   897  			req: &Request{
   898  				Method: "CONNECT",
   899  				Header: Header{},
   900  				URL:    u,
   901  				Host:   "example.com:123",
   902  			},
   903  			want: "example.com:123",
   904  		},
   905  	}
   906  
   907  	for i, tt := range tests {
   908  		res, err := cst.c.Do(tt.req)
   909  		if err != nil {
   910  			t.Errorf("%d. RoundTrip = %v", i, err)
   911  			continue
   912  		}
   913  		res.Body.Close()
   914  		req := <-gotc
   915  		if req.Method != "CONNECT" {
   916  			t.Errorf("method = %q; want CONNECT", req.Method)
   917  		}
   918  		if req.Host != tt.want {
   919  			t.Errorf("Host = %q; want %q", req.Host, tt.want)
   920  		}
   921  		if req.URL.Host != tt.want {
   922  			t.Errorf("URL.Host = %q; want %q", req.URL.Host, tt.want)
   923  		}
   924  	}
   925  }
   926  
   927  func TestTransportUserAgent(t *testing.T) { run(t, testTransportUserAgent) }
   928  func testTransportUserAgent(t *testing.T, mode testMode) {
   929  	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
   930  		fmt.Fprintf(w, "%q", r.Header["User-Agent"])
   931  	}))
   932  
   933  	either := func(a, b string) string {
   934  		if mode == http2Mode {
   935  			return b
   936  		}
   937  		return a
   938  	}
   939  
   940  	tests := []struct {
   941  		setup func(*Request)
   942  		want  string
   943  	}{
   944  		{
   945  			func(r *Request) {},
   946  			either(`["Go-http-client/1.1"]`, `["Go-http-client/2.0"]`),
   947  		},
   948  		{
   949  			func(r *Request) { r.Header.Set("User-Agent", "foo/1.2.3") },
   950  			`["foo/1.2.3"]`,
   951  		},
   952  		{
   953  			func(r *Request) { r.Header["User-Agent"] = []string{"single", "or", "multiple"} },
   954  			`["single"]`,
   955  		},
   956  		{
   957  			func(r *Request) { r.Header.Set("User-Agent", "") },
   958  			`[]`,
   959  		},
   960  		{
   961  			func(r *Request) { r.Header["User-Agent"] = nil },
   962  			`[]`,
   963  		},
   964  	}
   965  	for i, tt := range tests {
   966  		req, _ := NewRequest("GET", cst.ts.URL, nil)
   967  		tt.setup(req)
   968  		res, err := cst.c.Do(req)
   969  		if err != nil {
   970  			t.Errorf("%d. RoundTrip = %v", i, err)
   971  			continue
   972  		}
   973  		slurp, err := io.ReadAll(res.Body)
   974  		res.Body.Close()
   975  		if err != nil {
   976  			t.Errorf("%d. read body = %v", i, err)
   977  			continue
   978  		}
   979  		if string(slurp) != tt.want {
   980  			t.Errorf("%d. body mismatch.\n got: %s\nwant: %s\n", i, slurp, tt.want)
   981  		}
   982  	}
   983  }
   984  
   985  func TestStarRequestMethod(t *testing.T) {
   986  	for _, method := range []string{"FOO", "OPTIONS"} {
   987  		t.Run(method, func(t *testing.T) {
   988  			run(t, func(t *testing.T, mode testMode) {
   989  				testStarRequest(t, method, mode)
   990  			})
   991  		})
   992  	}
   993  }
   994  func testStarRequest(t *testing.T, method string, mode testMode) {
   995  	gotc := make(chan *Request, 1)
   996  	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
   997  		w.Header().Set("foo", "bar")
   998  		gotc <- r
   999  		w.(Flusher).Flush()
  1000  	}))
  1001  
  1002  	u, err := url.Parse(cst.ts.URL)
  1003  	if err != nil {
  1004  		t.Fatal(err)
  1005  	}
  1006  	u.Path = "*"
  1007  
  1008  	req := &Request{
  1009  		Method: method,
  1010  		Header: Header{},
  1011  		URL:    u,
  1012  	}
  1013  
  1014  	res, err := cst.c.Do(req)
  1015  	if err != nil {
  1016  		t.Fatalf("RoundTrip = %v", err)
  1017  	}
  1018  	res.Body.Close()
  1019  
  1020  	wantFoo := "bar"
  1021  	wantLen := int64(-1)
  1022  	if method == "OPTIONS" {
  1023  		wantFoo = ""
  1024  		wantLen = 0
  1025  	}
  1026  	if res.StatusCode != 200 {
  1027  		t.Errorf("status code = %v; want %d", res.Status, 200)
  1028  	}
  1029  	if res.ContentLength != wantLen {
  1030  		t.Errorf("content length = %v; want %d", res.ContentLength, wantLen)
  1031  	}
  1032  	if got := res.Header.Get("foo"); got != wantFoo {
  1033  		t.Errorf("response \"foo\" header = %q; want %q", got, wantFoo)
  1034  	}
  1035  	select {
  1036  	case req = <-gotc:
  1037  	default:
  1038  		req = nil
  1039  	}
  1040  	if req == nil {
  1041  		if method != "OPTIONS" {
  1042  			t.Fatalf("handler never got request")
  1043  		}
  1044  		return
  1045  	}
  1046  	if req.Method != method {
  1047  		t.Errorf("method = %q; want %q", req.Method, method)
  1048  	}
  1049  	if req.URL.Path != "*" {
  1050  		t.Errorf("URL.Path = %q; want *", req.URL.Path)
  1051  	}
  1052  	if req.RequestURI != "*" {
  1053  		t.Errorf("RequestURI = %q; want *", req.RequestURI)
  1054  	}
  1055  }
  1056  
  1057  // Issue 13957
  1058  func TestTransportDiscardsUnneededConns(t *testing.T) {
  1059  	run(t, testTransportDiscardsUnneededConns, []testMode{http2Mode})
  1060  }
  1061  func testTransportDiscardsUnneededConns(t *testing.T, mode testMode) {
  1062  	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
  1063  		fmt.Fprintf(w, "Hello, %v", r.RemoteAddr)
  1064  	}))
  1065  	defer cst.close()
  1066  
  1067  	var numOpen, numClose int32 // atomic
  1068  
  1069  	tlsConfig := &tls.Config{InsecureSkipVerify: true}
  1070  	tr := &Transport{
  1071  		TLSClientConfig: tlsConfig,
  1072  		DialTLS: func(_, addr string) (net.Conn, error) {
  1073  			time.Sleep(10 * time.Millisecond)
  1074  			rc, err := net.Dial("tcp", addr)
  1075  			if err != nil {
  1076  				return nil, err
  1077  			}
  1078  			atomic.AddInt32(&numOpen, 1)
  1079  			c := noteCloseConn{rc, func() { atomic.AddInt32(&numClose, 1) }}
  1080  			return TLSClientFactory(c, tlsConfig), nil
  1081  		},
  1082  	}
  1083  	if err := ExportHttp2ConfigureTransport(tr); err != nil {
  1084  		t.Fatal(err)
  1085  	}
  1086  	defer tr.CloseIdleConnections()
  1087  
  1088  	c := &Client{Transport: tr}
  1089  
  1090  	const N = 10
  1091  	gotBody := make(chan string, N)
  1092  	var wg sync.WaitGroup
  1093  	for i := 0; i < N; i++ {
  1094  		wg.Add(1)
  1095  		go func() {
  1096  			defer wg.Done()
  1097  			resp, err := c.Get(cst.ts.URL)
  1098  			if err != nil {
  1099  				// Try to work around spurious connection reset on loaded system.
  1100  				// See golang.org/issue/33585 and golang.org/issue/36797.
  1101  				time.Sleep(10 * time.Millisecond)
  1102  				resp, err = c.Get(cst.ts.URL)
  1103  				if err != nil {
  1104  					t.Errorf("Get: %v", err)
  1105  					return
  1106  				}
  1107  			}
  1108  			defer resp.Body.Close()
  1109  			slurp, err := io.ReadAll(resp.Body)
  1110  			if err != nil {
  1111  				t.Error(err)
  1112  			}
  1113  			gotBody <- string(slurp)
  1114  		}()
  1115  	}
  1116  	wg.Wait()
  1117  	close(gotBody)
  1118  
  1119  	var last string
  1120  	for got := range gotBody {
  1121  		if last == "" {
  1122  			last = got
  1123  			continue
  1124  		}
  1125  		if got != last {
  1126  			t.Errorf("Response body changed: %q -> %q", last, got)
  1127  		}
  1128  	}
  1129  
  1130  	var open, close int32
  1131  	for i := 0; i < 150; i++ {
  1132  		open, close = atomic.LoadInt32(&numOpen), atomic.LoadInt32(&numClose)
  1133  		if open < 1 {
  1134  			t.Fatalf("open = %d; want at least", open)
  1135  		}
  1136  		if close == open-1 {
  1137  			// Success
  1138  			return
  1139  		}
  1140  		time.Sleep(10 * time.Millisecond)
  1141  	}
  1142  	t.Errorf("%d connections opened, %d closed; want %d to close", open, close, open-1)
  1143  }
  1144  
  1145  // tests that Transport doesn't retain a pointer to the provided request.
  1146  func TestTransportGCRequest(t *testing.T) {
  1147  	run(t, func(t *testing.T, mode testMode) {
  1148  		t.Run("Body", func(t *testing.T) { testTransportGCRequest(t, mode, true) })
  1149  		t.Run("NoBody", func(t *testing.T) { testTransportGCRequest(t, mode, false) })
  1150  	})
  1151  }
  1152  func testTransportGCRequest(t *testing.T, mode testMode, body bool) {
  1153  	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
  1154  		io.ReadAll(r.Body)
  1155  		if body {
  1156  			io.WriteString(w, "Hello.")
  1157  		}
  1158  	}))
  1159  
  1160  	didGC := make(chan struct{})
  1161  	(func() {
  1162  		body := strings.NewReader("some body")
  1163  		req, _ := NewRequest("POST", cst.ts.URL, body)
  1164  		runtime.SetFinalizer(req, func(*Request) { close(didGC) })
  1165  		res, err := cst.c.Do(req)
  1166  		if err != nil {
  1167  			t.Fatal(err)
  1168  		}
  1169  		if _, err := io.ReadAll(res.Body); err != nil {
  1170  			t.Fatal(err)
  1171  		}
  1172  		if err := res.Body.Close(); err != nil {
  1173  			t.Fatal(err)
  1174  		}
  1175  	})()
  1176  	timeout := time.NewTimer(5 * time.Second)
  1177  	defer timeout.Stop()
  1178  	for {
  1179  		select {
  1180  		case <-didGC:
  1181  			return
  1182  		case <-time.After(100 * time.Millisecond):
  1183  			runtime.GC()
  1184  		case <-timeout.C:
  1185  			t.Fatal("never saw GC of request")
  1186  		}
  1187  	}
  1188  }
  1189  
  1190  func TestTransportRejectsInvalidHeaders(t *testing.T) { run(t, testTransportRejectsInvalidHeaders) }
  1191  func testTransportRejectsInvalidHeaders(t *testing.T, mode testMode) {
  1192  	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
  1193  		fmt.Fprintf(w, "Handler saw headers: %q", r.Header)
  1194  	}), optQuietLog)
  1195  	cst.tr.DisableKeepAlives = true
  1196  
  1197  	tests := []struct {
  1198  		key, val string
  1199  		ok       bool
  1200  	}{
  1201  		{"Foo", "capital-key", true}, // verify h2 allows capital keys
  1202  		{"Foo", "foo\x00bar", false}, // \x00 byte in value not allowed
  1203  		{"Foo", "two\nlines", false}, // \n byte in value not allowed
  1204  		{"bogus\nkey", "v", false},   // \n byte also not allowed in key
  1205  		{"A space", "v", false},      // spaces in keys not allowed
  1206  		{"имя", "v", false},          // key must be ascii
  1207  		{"name", "валю", true},       // value may be non-ascii
  1208  		{"", "v", false},             // key must be non-empty
  1209  		{"k", "", true},              // value may be empty
  1210  	}
  1211  	for _, tt := range tests {
  1212  		dialedc := make(chan bool, 1)
  1213  		cst.tr.Dial = func(netw, addr string) (net.Conn, error) {
  1214  			dialedc <- true
  1215  			return net.Dial(netw, addr)
  1216  		}
  1217  		req, _ := NewRequest("GET", cst.ts.URL, nil)
  1218  		req.Header[tt.key] = []string{tt.val}
  1219  		res, err := cst.c.Do(req)
  1220  		var body []byte
  1221  		if err == nil {
  1222  			body, _ = io.ReadAll(res.Body)
  1223  			res.Body.Close()
  1224  		}
  1225  		var dialed bool
  1226  		select {
  1227  		case <-dialedc:
  1228  			dialed = true
  1229  		default:
  1230  		}
  1231  
  1232  		if !tt.ok && dialed {
  1233  			t.Errorf("For key %q, value %q, transport dialed. Expected local failure. Response was: (%v, %v)\nServer replied with: %s", tt.key, tt.val, res, err, body)
  1234  		} else if (err == nil) != tt.ok {
  1235  			t.Errorf("For key %q, value %q; got err = %v; want ok=%v", tt.key, tt.val, err, tt.ok)
  1236  		}
  1237  	}
  1238  }
  1239  
// TestInterruptWithPanic checks what happens when a handler panics after it
// has already flushed part of the response. The client must read the flushed
// bytes followed by an error. The server ErrorLog must contain a stack trace
// for the "boom" panic, and nothing for a nil panic or ErrAbortHandler.
// testNotParallel is passed because the "nil" subtest calls t.Setenv, which
// the testing package does not allow in parallel tests.
func TestInterruptWithPanic(t *testing.T) {
	run(t, func(t *testing.T, mode testMode) {
		t.Run("boom", func(t *testing.T) { testInterruptWithPanic(t, mode, "boom") })
		// NOTE(review): panicnil=1 presumably keeps panic(nil) panicking with a
		// nil value (pre-Go 1.21 behavior) rather than *runtime.PanicNilError — confirm.
		t.Run("nil", func(t *testing.T) { t.Setenv("GODEBUG", "panicnil=1"); testInterruptWithPanic(t, mode, nil) })
		t.Run("ErrAbortHandler", func(t *testing.T) { testInterruptWithPanic(t, mode, ErrAbortHandler) })
	}, testNotParallel)
}

// testInterruptWithPanic starts a handler that writes and flushes msg. The
// handler waits until the client has received the response headers (or the
// test is returning), then panics with panicValue. The test then checks the
// client-side read result and the server's ErrorLog output.
func testInterruptWithPanic(t *testing.T, mode testMode, panicValue any) {
	const msg = "hello"

	// Closed on return so the handler's select below cannot block forever if
	// the test exits before signalling gotHeaders.
	testDone := make(chan struct{})
	defer close(testDone)

	var errorLog lockedBytesBuffer
	// Buffered so the test's single send never blocks on the handler.
	gotHeaders := make(chan bool, 1)
	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		io.WriteString(w, msg)
		w.(Flusher).Flush()

		select {
		case <-gotHeaders:
		case <-testDone:
		}
		panic(panicValue)
	}), func(ts *httptest.Server) {
		ts.Config.ErrorLog = log.New(&errorLog, "", 0)
	})
	res, err := cst.c.Get(cst.ts.URL)
	if err != nil {
		t.Fatal(err)
	}
	// Headers are in; let the handler panic now.
	gotHeaders <- true
	defer res.Body.Close()
	// The flushed msg must arrive, and the read must then fail because the
	// handler aborted the response.
	slurp, err := io.ReadAll(res.Body)
	if string(slurp) != msg {
		t.Errorf("client read %q; want %q", slurp, msg)
	}
	if err == nil {
		t.Errorf("client read all successfully; want some error")
	}
	// errorLog is written by server goroutines; read it under its lock.
	logOutput := func() string {
		errorLog.Lock()
		defer errorLog.Unlock()
		return errorLog.String()
	}
	// Only ordinary panic values are expected to be logged with a stack.
	wantStackLogged := panicValue != nil && panicValue != ErrAbortHandler

	// NOTE(review): waitCondition is defined elsewhere; it appears to call the
	// callback repeatedly (starting at 10ms) until it returns true, with d
	// being the delay before the current attempt — confirm.
	waitCondition(t, 10*time.Millisecond, func(d time.Duration) bool {
		gotLog := logOutput()
		if !wantStackLogged {
			if gotLog == "" {
				return true
			}
			t.Fatalf("want no log output; got: %s", gotLog)
		}
		if gotLog == "" {
			if d > 0 {
				t.Logf("wanted a stack trace logged; got nothing after %v", d)
			}
			return false
		}
		// Heuristic: accept the log as a stack trace if it mentions
		// "created by " or spans at least 6 lines.
		if !strings.Contains(gotLog, "created by ") && strings.Count(gotLog, "\n") < 6 {
			if d > 0 {
				t.Logf("output doesn't look like a panic stack trace after %v. Got: %s", d, gotLog)
			}
			return false
		}
		return true
	})
}
  1310  
  1311  type lockedBytesBuffer struct {
  1312  	sync.Mutex
  1313  	bytes.Buffer
  1314  }
  1315  
  1316  func (b *lockedBytesBuffer) Write(p []byte) (int, error) {
  1317  	b.Lock()
  1318  	defer b.Unlock()
  1319  	return b.Buffer.Write(p)
  1320  }
  1321  
  1322  // Issue 15366
  1323  func TestH12_AutoGzipWithDumpResponse(t *testing.T) {
  1324  	h12Compare{
  1325  		Handler: func(w ResponseWriter, r *Request) {
  1326  			h := w.Header()
  1327  			h.Set("Content-Encoding", "gzip")
  1328  			h.Set("Content-Length", "23")
  1329  			io.WriteString(w, "\x1f\x8b\b\x00\x00\x00\x00\x00\x00\x00s\xf3\xf7\a\x00\xab'\xd4\x1a\x03\x00\x00\x00")
  1330  		},
  1331  		EarlyCheckResponse: func(proto string, res *Response) {
  1332  			if !res.Uncompressed {
  1333  				t.Errorf("%s: expected Uncompressed to be set", proto)
  1334  			}
  1335  			dump, err := httputil.DumpResponse(res, true)
  1336  			if err != nil {
  1337  				t.Errorf("%s: DumpResponse: %v", proto, err)
  1338  				return
  1339  			}
  1340  			if strings.Contains(string(dump), "Connection: close") {
  1341  				t.Errorf("%s: should not see \"Connection: close\" in dump; got:\n%s", proto, dump)
  1342  			}
  1343  			if !strings.Contains(string(dump), "FOO") {
  1344  				t.Errorf("%s: should see \"FOO\" in response; got:\n%s", proto, dump)
  1345  			}
  1346  		},
  1347  	}.run(t)
  1348  }
  1349  
  1350  // Issue 14607
  1351  func TestCloseIdleConnections(t *testing.T) { run(t, testCloseIdleConnections) }
  1352  func testCloseIdleConnections(t *testing.T, mode testMode) {
  1353  	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
  1354  		w.Header().Set("X-Addr", r.RemoteAddr)
  1355  	}))
  1356  	get := func() string {
  1357  		res, err := cst.c.Get(cst.ts.URL)
  1358  		if err != nil {
  1359  			t.Fatal(err)
  1360  		}
  1361  		res.Body.Close()
  1362  		v := res.Header.Get("X-Addr")
  1363  		if v == "" {
  1364  			t.Fatal("didn't get X-Addr")
  1365  		}
  1366  		return v
  1367  	}
  1368  	a1 := get()
  1369  	cst.tr.CloseIdleConnections()
  1370  	a2 := get()
  1371  	if a1 == a2 {
  1372  		t.Errorf("didn't close connection")
  1373  	}
  1374  }
  1375  
  1376  type noteCloseConn struct {
  1377  	net.Conn
  1378  	closeFunc func()
  1379  }
  1380  
  1381  func (x noteCloseConn) Close() error {
  1382  	x.closeFunc()
  1383  	return x.Conn.Close()
  1384  }
  1385  
  1386  type testErrorReader struct{ t *testing.T }
  1387  
  1388  func (r testErrorReader) Read(p []byte) (n int, err error) {
  1389  	r.t.Error("unexpected Read call")
  1390  	return 0, io.EOF
  1391  }
  1392  
  1393  func TestNoSniffExpectRequestBody(t *testing.T) { run(t, testNoSniffExpectRequestBody) }
  1394  func testNoSniffExpectRequestBody(t *testing.T, mode testMode) {
  1395  	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
  1396  		w.WriteHeader(StatusUnauthorized)
  1397  	}))
  1398  
  1399  	// Set ExpectContinueTimeout non-zero so RoundTrip won't try to write it.
  1400  	cst.tr.ExpectContinueTimeout = 10 * time.Second
  1401  
  1402  	req, err := NewRequest("POST", cst.ts.URL, testErrorReader{t})
  1403  	if err != nil {
  1404  		t.Fatal(err)
  1405  	}
  1406  	req.ContentLength = 0 // so transport is tempted to sniff it
  1407  	req.Header.Set("Expect", "100-continue")
  1408  	res, err := cst.tr.RoundTrip(req)
  1409  	if err != nil {
  1410  		t.Fatal(err)
  1411  	}
  1412  	defer res.Body.Close()
  1413  	if res.StatusCode != StatusUnauthorized {
  1414  		t.Errorf("status code = %v; want %v", res.StatusCode, StatusUnauthorized)
  1415  	}
  1416  }
  1417  
  1418  func TestServerUndeclaredTrailers(t *testing.T) { run(t, testServerUndeclaredTrailers) }
  1419  func testServerUndeclaredTrailers(t *testing.T, mode testMode) {
  1420  	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
  1421  		w.Header().Set("Foo", "Bar")
  1422  		w.Header().Set("Trailer:Foo", "Baz")
  1423  		w.(Flusher).Flush()
  1424  		w.Header().Add("Trailer:Foo", "Baz2")
  1425  		w.Header().Set("Trailer:Bar", "Quux")
  1426  	}))
  1427  	res, err := cst.c.Get(cst.ts.URL)
  1428  	if err != nil {
  1429  		t.Fatal(err)
  1430  	}
  1431  	if _, err := io.Copy(io.Discard, res.Body); err != nil {
  1432  		t.Fatal(err)
  1433  	}
  1434  	res.Body.Close()
  1435  	delete(res.Header, "Date")
  1436  	delete(res.Header, "Content-Type")
  1437  
  1438  	if want := (Header{"Foo": {"Bar"}}); !reflect.DeepEqual(res.Header, want) {
  1439  		t.Errorf("Header = %#v; want %#v", res.Header, want)
  1440  	}
  1441  	if want := (Header{"Foo": {"Baz", "Baz2"}, "Bar": {"Quux"}}); !reflect.DeepEqual(res.Trailer, want) {
  1442  		t.Errorf("Trailer = %#v; want %#v", res.Trailer, want)
  1443  	}
  1444  }
  1445  
  1446  func TestBadResponseAfterReadingBody(t *testing.T) {
  1447  	run(t, testBadResponseAfterReadingBody, []testMode{http1Mode})
  1448  }
  1449  func testBadResponseAfterReadingBody(t *testing.T, mode testMode) {
  1450  	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
  1451  		_, err := io.Copy(io.Discard, r.Body)
  1452  		if err != nil {
  1453  			t.Fatal(err)
  1454  		}
  1455  		c, _, err := w.(Hijacker).Hijack()
  1456  		if err != nil {
  1457  			t.Fatal(err)
  1458  		}
  1459  		defer c.Close()
  1460  		fmt.Fprintln(c, "some bogus crap")
  1461  	}))
  1462  
  1463  	closes := 0
  1464  	res, err := cst.c.Post(cst.ts.URL, "text/plain", countCloseReader{&closes, strings.NewReader("hello")})
  1465  	if err == nil {
  1466  		res.Body.Close()
  1467  		t.Fatal("expected an error to be returned from Post")
  1468  	}
  1469  	if closes != 1 {
  1470  		t.Errorf("closes = %d; want 1", closes)
  1471  	}
  1472  }
  1473  
  1474  func TestWriteHeader0(t *testing.T) { run(t, testWriteHeader0) }
  1475  func testWriteHeader0(t *testing.T, mode testMode) {
  1476  	gotpanic := make(chan bool, 1)
  1477  	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
  1478  		defer close(gotpanic)
  1479  		defer func() {
  1480  			if e := recover(); e != nil {
  1481  				got := fmt.Sprintf("%T, %v", e, e)
  1482  				want := "string, invalid WriteHeader code 0"
  1483  				if got != want {
  1484  					t.Errorf("unexpected panic value:\n got: %v\nwant: %v\n", got, want)
  1485  				}
  1486  				gotpanic <- true
  1487  
  1488  				// Set an explicit 503. This also tests that the WriteHeader call panics
  1489  				// before it recorded that an explicit value was set and that bogus
  1490  				// value wasn't stuck.
  1491  				w.WriteHeader(503)
  1492  			}
  1493  		}()
  1494  		w.WriteHeader(0)
  1495  	}))
  1496  	res, err := cst.c.Get(cst.ts.URL)
  1497  	if err != nil {
  1498  		t.Fatal(err)
  1499  	}
  1500  	if res.StatusCode != 503 {
  1501  		t.Errorf("Response: %v %q; want 503", res.StatusCode, res.Status)
  1502  	}
  1503  	if !<-gotpanic {
  1504  		t.Error("expected panic in handler")
  1505  	}
  1506  }
  1507  
  1508  // Issue 23010: don't be super strict checking WriteHeader's code if
  1509  // it's not even valid to call WriteHeader then anyway.
  1510  func TestWriteHeaderNoCodeCheck(t *testing.T) {
  1511  	run(t, func(t *testing.T, mode testMode) {
  1512  		testWriteHeaderAfterWrite(t, mode, false)
  1513  	})
  1514  }
  1515  func TestWriteHeaderNoCodeCheck_h1hijack(t *testing.T) {
  1516  	testWriteHeaderAfterWrite(t, http1Mode, true)
  1517  }
  1518  func testWriteHeaderAfterWrite(t *testing.T, mode testMode, hijack bool) {
  1519  	var errorLog lockedBytesBuffer
  1520  	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
  1521  		if hijack {
  1522  			conn, _, _ := w.(Hijacker).Hijack()
  1523  			defer conn.Close()
  1524  			conn.Write([]byte("HTTP/1.1 200 OK\r\nContent-Length: 6\r\n\r\nfoo"))
  1525  			w.WriteHeader(0) // verify this doesn't panic if there's already output; Issue 23010
  1526  			conn.Write([]byte("bar"))
  1527  			return
  1528  		}
  1529  		io.WriteString(w, "foo")
  1530  		w.(Flusher).Flush()
  1531  		w.WriteHeader(0) // verify this doesn't panic if there's already output; Issue 23010
  1532  		io.WriteString(w, "bar")
  1533  	}), func(ts *httptest.Server) {
  1534  		ts.Config.ErrorLog = log.New(&errorLog, "", 0)
  1535  	})
  1536  	res, err := cst.c.Get(cst.ts.URL)
  1537  	if err != nil {
  1538  		t.Fatal(err)
  1539  	}
  1540  	defer res.Body.Close()
  1541  	body, err := io.ReadAll(res.Body)
  1542  	if err != nil {
  1543  		t.Fatal(err)
  1544  	}
  1545  	if got, want := string(body), "foobar"; got != want {
  1546  		t.Errorf("got = %q; want %q", got, want)
  1547  	}
  1548  
  1549  	// Also check the stderr output:
  1550  	if mode == http2Mode {
  1551  		// TODO: also emit this log message for HTTP/2?
  1552  		// We historically haven't, so don't check.
  1553  		return
  1554  	}
  1555  	gotLog := strings.TrimSpace(errorLog.String())
  1556  	wantLog := "http: superfluous response.WriteHeader call from github.com/ooni/oohttp.relevantCaller (server.go:"
  1557  	if hijack {
  1558  		wantLog = "http: response.WriteHeader on hijacked connection from github.com/ooni/oohttp.relevantCaller (server.go:"
  1559  	}
  1560  	if !strings.HasPrefix(gotLog, wantLog) {
  1561  		t.Errorf("stderr output = %q; want %q", gotLog, wantLog)
  1562  	}
  1563  }
  1564  
  1565  func TestBidiStreamReverseProxy(t *testing.T) {
  1566  	run(t, testBidiStreamReverseProxy, []testMode{http2Mode})
  1567  }
  1568  func testBidiStreamReverseProxy(t *testing.T, mode testMode) {
  1569  	backend := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
  1570  		if _, err := io.Copy(w, r.Body); err != nil {
  1571  			log.Printf("bidi backend copy: %v", err)
  1572  		}
  1573  	}))
  1574  
  1575  	backURL, err := url.Parse(backend.ts.URL)
  1576  	if err != nil {
  1577  		t.Fatal(err)
  1578  	}
  1579  	rp := httputil.NewSingleHostReverseProxy(backURL)
  1580  	rp.Transport = backend.tr
  1581  	proxy := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
  1582  		rp.ServeHTTP(w, r)
  1583  	}))
  1584  
  1585  	bodyRes := make(chan any, 1) // error or hash.Hash
  1586  	pr, pw := io.Pipe()
  1587  	req, _ := NewRequest("PUT", proxy.ts.URL, pr)
  1588  	const size = 4 << 20
  1589  	go func() {
  1590  		h := sha1.New()
  1591  		_, err := io.CopyN(io.MultiWriter(h, pw), rand.Reader, size)
  1592  		go pw.Close()
  1593  		if err != nil {
  1594  			bodyRes <- err
  1595  		} else {
  1596  			bodyRes <- h
  1597  		}
  1598  	}()
  1599  	res, err := backend.c.Do(req)
  1600  	if err != nil {
  1601  		t.Fatal(err)
  1602  	}
  1603  	defer res.Body.Close()
  1604  	hgot := sha1.New()
  1605  	n, err := io.Copy(hgot, res.Body)
  1606  	if err != nil {
  1607  		t.Fatal(err)
  1608  	}
  1609  	if n != size {
  1610  		t.Fatalf("got %d bytes; want %d", n, size)
  1611  	}
  1612  	select {
  1613  	case v := <-bodyRes:
  1614  		switch v := v.(type) {
  1615  		default:
  1616  			t.Fatalf("body copy: %v", err)
  1617  		case hash.Hash:
  1618  			if !bytes.Equal(v.Sum(nil), hgot.Sum(nil)) {
  1619  				t.Errorf("written bytes didn't match received bytes")
  1620  			}
  1621  		}
  1622  	case <-time.After(10 * time.Second):
  1623  		t.Fatal("timeout")
  1624  	}
  1625  
  1626  }
  1627  
  1628  // Always use HTTP/1.1 for WebSocket upgrades.
  1629  func TestH12_WebSocketUpgrade(t *testing.T) {
  1630  	h12Compare{
  1631  		Handler: func(w ResponseWriter, r *Request) {
  1632  			h := w.Header()
  1633  			h.Set("Foo", "bar")
  1634  		},
  1635  		ReqFunc: func(c *Client, url string) (*Response, error) {
  1636  			req, _ := NewRequest("GET", url, nil)
  1637  			req.Header.Set("Connection", "Upgrade")
  1638  			req.Header.Set("Upgrade", "WebSocket")
  1639  			return c.Do(req)
  1640  		},
  1641  		EarlyCheckResponse: func(proto string, res *Response) {
  1642  			if res.Proto != "HTTP/1.1" {
  1643  				t.Errorf("%s: expected HTTP/1.1, got %q", proto, res.Proto)
  1644  			}
  1645  			res.Proto = "HTTP/IGNORE" // skip later checks that Proto must be 1.1 vs 2.0
  1646  		},
  1647  	}.run(t)
  1648  }
  1649  
  1650  func TestIdentityTransferEncoding(t *testing.T) { run(t, testIdentityTransferEncoding) }
  1651  func testIdentityTransferEncoding(t *testing.T, mode testMode) {
  1652  	const body = "body"
  1653  	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
  1654  		gotBody, _ := io.ReadAll(r.Body)
  1655  		if got, want := string(gotBody), body; got != want {
  1656  			t.Errorf("got request body = %q; want %q", got, want)
  1657  		}
  1658  		w.Header().Set("Transfer-Encoding", "identity")
  1659  		w.WriteHeader(StatusOK)
  1660  		w.(Flusher).Flush()
  1661  		io.WriteString(w, body)
  1662  	}))
  1663  	req, _ := NewRequest("GET", cst.ts.URL, strings.NewReader(body))
  1664  	res, err := cst.c.Do(req)
  1665  	if err != nil {
  1666  		t.Fatal(err)
  1667  	}
  1668  	defer res.Body.Close()
  1669  	gotBody, err := io.ReadAll(res.Body)
  1670  	if err != nil {
  1671  		t.Fatal(err)
  1672  	}
  1673  	if got, want := string(gotBody), body; got != want {
  1674  		t.Errorf("got response body = %q; want %q", got, want)
  1675  	}
  1676  }
  1677  
  1678  func TestEarlyHintsRequest(t *testing.T) { run(t, testEarlyHintsRequest) }
  1679  func testEarlyHintsRequest(t *testing.T, mode testMode) {
  1680  	var wg sync.WaitGroup
  1681  	wg.Add(1)
  1682  	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
  1683  		h := w.Header()
  1684  
  1685  		h.Add("Content-Length", "123") // must be ignored
  1686  		h.Add("Link", "</style.css>; rel=preload; as=style")
  1687  		h.Add("Link", "</script.js>; rel=preload; as=script")
  1688  		w.WriteHeader(StatusEarlyHints)
  1689  
  1690  		wg.Wait()
  1691  
  1692  		h.Add("Link", "</foo.js>; rel=preload; as=script")
  1693  		w.WriteHeader(StatusEarlyHints)
  1694  
  1695  		w.Write([]byte("Hello"))
  1696  	}))
  1697  
  1698  	checkLinkHeaders := func(t *testing.T, expected, got []string) {
  1699  		t.Helper()
  1700  
  1701  		if len(expected) != len(got) {
  1702  			t.Errorf("got %d expected %d", len(got), len(expected))
  1703  		}
  1704  
  1705  		for i := range expected {
  1706  			if expected[i] != got[i] {
  1707  				t.Errorf("got %q expected %q", got[i], expected[i])
  1708  			}
  1709  		}
  1710  	}
  1711  
  1712  	checkExcludedHeaders := func(t *testing.T, header textproto.MIMEHeader) {
  1713  		t.Helper()
  1714  
  1715  		for _, h := range []string{"Content-Length", "Transfer-Encoding"} {
  1716  			if v, ok := header[h]; ok {
  1717  				t.Errorf("%s is %q; must not be sent", h, v)
  1718  			}
  1719  		}
  1720  	}
  1721  
  1722  	var respCounter uint8
  1723  	trace := &httptrace.ClientTrace{
  1724  		Got1xxResponse: func(code int, header textproto.MIMEHeader) error {
  1725  			switch respCounter {
  1726  			case 0:
  1727  				checkLinkHeaders(t, []string{"</style.css>; rel=preload; as=style", "</script.js>; rel=preload; as=script"}, header["Link"])
  1728  				checkExcludedHeaders(t, header)
  1729  
  1730  				wg.Done()
  1731  			case 1:
  1732  				checkLinkHeaders(t, []string{"</style.css>; rel=preload; as=style", "</script.js>; rel=preload; as=script", "</foo.js>; rel=preload; as=script"}, header["Link"])
  1733  				checkExcludedHeaders(t, header)
  1734  
  1735  			default:
  1736  				t.Error("Unexpected 1xx response")
  1737  			}
  1738  
  1739  			respCounter++
  1740  
  1741  			return nil
  1742  		},
  1743  	}
  1744  	req, _ := NewRequestWithContext(httptrace.WithClientTrace(context.Background(), trace), "GET", cst.ts.URL, nil)
  1745  
  1746  	res, err := cst.c.Do(req)
  1747  	if err != nil {
  1748  		t.Fatal(err)
  1749  	}
  1750  	defer res.Body.Close()
  1751  
  1752  	checkLinkHeaders(t, []string{"</style.css>; rel=preload; as=style", "</script.js>; rel=preload; as=script", "</foo.js>; rel=preload; as=script"}, res.Header["Link"])
  1753  	if cl := res.Header.Get("Content-Length"); cl != "123" {
  1754  		t.Errorf("Content-Length is %q; want 123", cl)
  1755  	}
  1756  
  1757  	body, _ := io.ReadAll(res.Body)
  1758  	if string(body) != "Hello" {
  1759  		t.Errorf("Read body %q; want Hello", body)
  1760  	}
  1761  }