golang.org/x/net@v0.25.1-0.20240516223405-c87a5b62e243/http2/clientconn_test.go (about)

     1  // Copyright 2024 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  // Infrastructure for testing ClientConn.RoundTrip.
     6  // Put actual tests in transport_test.go.
     7  
     8  package http2
     9  
    10  import (
    11  	"bytes"
    12  	"fmt"
    13  	"io"
    14  	"net"
    15  	"net/http"
    16  	"reflect"
    17  	"slices"
    18  	"testing"
    19  	"time"
    20  
    21  	"golang.org/x/net/http2/hpack"
    22  )
    23  
    24  // TestTestClientConn demonstrates usage of testClientConn.
    25  func TestTestClientConn(t *testing.T) {
    26  	// newTestClientConn creates a *ClientConn and surrounding test infrastructure.
    27  	tc := newTestClientConn(t)
    28  
    29  	// tc.greet reads the client's initial SETTINGS and WINDOW_UPDATE frames,
    30  	// and sends a SETTINGS frame to the client.
    31  	//
    32  	// Additional settings may be provided as optional parameters to greet.
    33  	tc.greet()
    34  
    35  	// Request bodies must either be constant (bytes.Buffer, strings.Reader)
    36  	// or created with newRequestBody.
    37  	body := tc.newRequestBody()
    38  	body.writeBytes(10)         // 10 arbitrary bytes...
    39  	body.closeWithError(io.EOF) // ...followed by EOF.
    40  
    41  	// tc.roundTrip calls RoundTrip, but does not wait for it to return.
    42  	// It returns a testRoundTrip.
    43  	req, _ := http.NewRequest("PUT", "https://dummy.tld/", body)
    44  	rt := tc.roundTrip(req)
    45  
    46  	// tc has a number of methods to check for expected frames sent.
    47  	// Here, we look for headers and the request body.
    48  	tc.wantHeaders(wantHeader{
    49  		streamID:  rt.streamID(),
    50  		endStream: false,
    51  		header: http.Header{
    52  			":authority": []string{"dummy.tld"},
    53  			":method":    []string{"PUT"},
    54  			":path":      []string{"/"},
    55  		},
    56  	})
    57  	// Expect 10 bytes of request body in DATA frames.
    58  	tc.wantData(wantData{
    59  		streamID:  rt.streamID(),
    60  		endStream: true,
    61  		size:      10,
    62  	})
    63  
    64  	// tc.writeHeaders sends a HEADERS frame back to the client.
    65  	tc.writeHeaders(HeadersFrameParam{
    66  		StreamID:   rt.streamID(),
    67  		EndHeaders: true,
    68  		EndStream:  true,
    69  		BlockFragment: tc.makeHeaderBlockFragment(
    70  			":status", "200",
    71  		),
    72  	})
    73  
    74  	// Now that we've received headers, RoundTrip has finished.
    75  	// testRoundTrip has various methods to examine the response,
    76  	// or to fetch the response and/or error returned by RoundTrip
    77  	rt.wantStatus(200)
    78  	rt.wantBody(nil)
    79  }
    80  
    81  // A testClientConn allows testing ClientConn.RoundTrip against a fake server.
    82  //
    83  // A test using testClientConn consists of:
    84  //   - actions on the client (calling RoundTrip, making data available to Request.Body);
    85  //   - validation of frames sent by the client to the server; and
    86  //   - providing frames from the server to the client.
    87  //
    88  // testClientConn manages synchronization, so tests can generally be written as
    89  // a linear sequence of actions and validations without additional synchronization.
    90  type testClientConn struct {
    91  	t *testing.T
    92  
    93  	tr    *Transport
    94  	fr    *Framer
    95  	cc    *ClientConn
    96  	hooks *testSyncHooks
    97  
    98  	encbuf bytes.Buffer
    99  	enc    *hpack.Encoder
   100  
   101  	roundtrips []*testRoundTrip
   102  
   103  	rerr          error        // returned by Read
   104  	netConnClosed bool         // set when the ClientConn closes the net.Conn
   105  	rbuf          bytes.Buffer // sent to the test conn
   106  	wbuf          bytes.Buffer // sent by the test conn
   107  }
   108  
   109  func newTestClientConnFromClientConn(t *testing.T, cc *ClientConn) *testClientConn {
   110  	tc := &testClientConn{
   111  		t:     t,
   112  		tr:    cc.t,
   113  		cc:    cc,
   114  		hooks: cc.t.syncHooks,
   115  	}
   116  	cc.tconn = (*testClientConnNetConn)(tc)
   117  	tc.enc = hpack.NewEncoder(&tc.encbuf)
   118  	tc.fr = NewFramer(&tc.rbuf, &tc.wbuf)
   119  	tc.fr.ReadMetaHeaders = hpack.NewDecoder(initialHeaderTableSize, nil)
   120  	tc.fr.SetMaxReadFrameSize(10 << 20)
   121  	t.Cleanup(func() {
   122  		tc.sync()
   123  		if tc.rerr == nil {
   124  			tc.rerr = io.EOF
   125  		}
   126  		tc.sync()
   127  	})
   128  	return tc
   129  }
   130  
   131  func (tc *testClientConn) readClientPreface() {
   132  	tc.t.Helper()
   133  	// Read the client's HTTP/2 preface, sent prior to any HTTP/2 frames.
   134  	buf := make([]byte, len(clientPreface))
   135  	if _, err := io.ReadFull(&tc.wbuf, buf); err != nil {
   136  		tc.t.Fatalf("reading preface: %v", err)
   137  	}
   138  	if !bytes.Equal(buf, clientPreface) {
   139  		tc.t.Fatalf("client preface: %q, want %q", buf, clientPreface)
   140  	}
   141  }
   142  
   143  func newTestClientConn(t *testing.T, opts ...func(*Transport)) *testClientConn {
   144  	t.Helper()
   145  
   146  	tt := newTestTransport(t, opts...)
   147  	const singleUse = false
   148  	_, err := tt.tr.newClientConn(nil, singleUse, tt.tr.syncHooks)
   149  	if err != nil {
   150  		t.Fatalf("newClientConn: %v", err)
   151  	}
   152  
   153  	return tt.getConn()
   154  }
   155  
   156  // sync waits for the ClientConn under test to reach a stable state,
   157  // with all goroutines blocked on some input.
   158  func (tc *testClientConn) sync() {
   159  	tc.hooks.waitInactive()
   160  }
   161  
   162  // advance advances synthetic time by a duration.
   163  func (tc *testClientConn) advance(d time.Duration) {
   164  	tc.hooks.advance(d)
   165  	tc.sync()
   166  }
   167  
   168  // hasFrame reports whether a frame is available to be read.
   169  func (tc *testClientConn) hasFrame() bool {
   170  	return tc.wbuf.Len() > 0
   171  }
   172  
   173  // readFrame reads the next frame from the conn.
   174  func (tc *testClientConn) readFrame() Frame {
   175  	if tc.wbuf.Len() == 0 {
   176  		return nil
   177  	}
   178  	fr, err := tc.fr.ReadFrame()
   179  	if err != nil {
   180  		return nil
   181  	}
   182  	return fr
   183  }
   184  
   185  // testClientConnReadFrame reads a frame of a specific type from the conn.
   186  func testClientConnReadFrame[T any](tc *testClientConn) T {
   187  	tc.t.Helper()
   188  	var v T
   189  	fr := tc.readFrame()
   190  	if fr == nil {
   191  		tc.t.Fatalf("got no frame, want frame %T", v)
   192  	}
   193  	v, ok := fr.(T)
   194  	if !ok {
   195  		tc.t.Fatalf("got frame %T, want %T", fr, v)
   196  	}
   197  	return v
   198  }
   199  
   200  // wantFrameType reads the next frame from the conn.
   201  // It produces an error if the frame type is not the expected value.
   202  func (tc *testClientConn) wantFrameType(want FrameType) {
   203  	tc.t.Helper()
   204  	fr := tc.readFrame()
   205  	if fr == nil {
   206  		tc.t.Fatalf("got no frame, want frame %v", want)
   207  	}
   208  	if got := fr.Header().Type; got != want {
   209  		tc.t.Fatalf("got frame %v, want %v", got, want)
   210  	}
   211  }
   212  
   213  // wantUnorderedFrames reads frames from the conn until every condition in want has been satisfied.
   214  //
   215  // want is a list of func(*SomeFrame) bool.
   216  // wantUnorderedFrames will call each func with frames of the appropriate type
   217  // until the func returns true.
   218  // It calls t.Fatal if an unexpected frame is received (no func has that frame type,
   219  // or all funcs with that type have returned true), or if the conn runs out of frames
   220  // with unsatisfied funcs.
   221  //
   222  // Example:
   223  //
   224  //	// Read a SETTINGS frame, and any number of DATA frames for a stream.
   225  //	// The SETTINGS frame may appear anywhere in the sequence.
   226  //	// The last DATA frame must indicate the end of the stream.
   227  //	tc.wantUnorderedFrames(
   228  //		func(f *SettingsFrame) bool {
   229  //			return true
   230  //		},
   231  //		func(f *DataFrame) bool {
   232  //			return f.StreamEnded()
   233  //		},
   234  //	)
   235  func (tc *testClientConn) wantUnorderedFrames(want ...any) {
   236  	tc.t.Helper()
   237  	want = slices.Clone(want)
   238  	seen := 0
   239  frame:
   240  	for seen < len(want) && !tc.t.Failed() {
   241  		fr := tc.readFrame()
   242  		if fr == nil {
   243  			break
   244  		}
   245  		for i, f := range want {
   246  			if f == nil {
   247  				continue
   248  			}
   249  			typ := reflect.TypeOf(f)
   250  			if typ.Kind() != reflect.Func ||
   251  				typ.NumIn() != 1 ||
   252  				typ.NumOut() != 1 ||
   253  				typ.Out(0) != reflect.TypeOf(true) {
   254  				tc.t.Fatalf("expected func(*SomeFrame) bool, got %T", f)
   255  			}
   256  			if typ.In(0) == reflect.TypeOf(fr) {
   257  				out := reflect.ValueOf(f).Call([]reflect.Value{reflect.ValueOf(fr)})
   258  				if out[0].Bool() {
   259  					want[i] = nil
   260  					seen++
   261  				}
   262  				continue frame
   263  			}
   264  		}
   265  		tc.t.Errorf("got unexpected frame type %T", fr)
   266  	}
   267  	if seen < len(want) {
   268  		for _, f := range want {
   269  			if f == nil {
   270  				continue
   271  			}
   272  			tc.t.Errorf("did not see expected frame: %v", reflect.TypeOf(f).In(0))
   273  		}
   274  		tc.t.Fatalf("did not see %v expected frame types", len(want)-seen)
   275  	}
   276  }
   277  
   278  type wantHeader struct {
   279  	streamID  uint32
   280  	endStream bool
   281  	header    http.Header
   282  }
   283  
   284  // wantHeaders reads a HEADERS frame and potential CONTINUATION frames,
   285  // and asserts that they contain the expected headers.
   286  func (tc *testClientConn) wantHeaders(want wantHeader) {
   287  	tc.t.Helper()
   288  	got := testClientConnReadFrame[*MetaHeadersFrame](tc)
   289  	if got, want := got.StreamID, want.streamID; got != want {
   290  		tc.t.Fatalf("got stream ID %v, want %v", got, want)
   291  	}
   292  	if got, want := got.StreamEnded(), want.endStream; got != want {
   293  		tc.t.Fatalf("got stream ended %v, want %v", got, want)
   294  	}
   295  	gotHeader := make(http.Header)
   296  	for _, f := range got.Fields {
   297  		gotHeader[f.Name] = append(gotHeader[f.Name], f.Value)
   298  	}
   299  	for k, v := range want.header {
   300  		if !reflect.DeepEqual(v, gotHeader[k]) {
   301  			tc.t.Fatalf("got header %q = %q; want %q", k, v, gotHeader[k])
   302  		}
   303  	}
   304  }
   305  
   306  type wantData struct {
   307  	streamID  uint32
   308  	endStream bool
   309  	size      int
   310  }
   311  
   312  // wantData reads zero or more DATA frames, and asserts that they match the expectation.
   313  func (tc *testClientConn) wantData(want wantData) {
   314  	tc.t.Helper()
   315  	gotSize := 0
   316  	gotEndStream := false
   317  	for tc.hasFrame() && !gotEndStream {
   318  		data := testClientConnReadFrame[*DataFrame](tc)
   319  		gotSize += len(data.Data())
   320  		if data.StreamEnded() {
   321  			gotEndStream = true
   322  		}
   323  	}
   324  	if gotSize != want.size {
   325  		tc.t.Fatalf("got %v bytes of DATA frames, want %v", gotSize, want.size)
   326  	}
   327  	if gotEndStream != want.endStream {
   328  		tc.t.Fatalf("after %v bytes of DATA frames, got END_STREAM=%v; want %v", gotSize, gotEndStream, want.endStream)
   329  	}
   330  }
   331  
   332  // testRequestBody is a Request.Body for use in tests.
   333  type testRequestBody struct {
   334  	tc *testClientConn
   335  
   336  	// At most one of buf or bytes can be set at any given time:
   337  	buf   bytes.Buffer // specific bytes to read from the body
   338  	bytes int          // body contains this many arbitrary bytes
   339  
   340  	err error // read error (comes after any available bytes)
   341  }
   342  
   343  func (tc *testClientConn) newRequestBody() *testRequestBody {
   344  	b := &testRequestBody{
   345  		tc: tc,
   346  	}
   347  	return b
   348  }
   349  
   350  // Read is called by the ClientConn to read from a request body.
   351  func (b *testRequestBody) Read(p []byte) (n int, _ error) {
   352  	b.tc.cc.syncHooks.blockUntil(func() bool {
   353  		return b.buf.Len() > 0 || b.bytes > 0 || b.err != nil
   354  	})
   355  	switch {
   356  	case b.buf.Len() > 0:
   357  		return b.buf.Read(p)
   358  	case b.bytes > 0:
   359  		if len(p) > b.bytes {
   360  			p = p[:b.bytes]
   361  		}
   362  		b.bytes -= len(p)
   363  		for i := range p {
   364  			p[i] = 'A'
   365  		}
   366  		return len(p), nil
   367  	default:
   368  		return 0, b.err
   369  	}
   370  }
   371  
   372  // Close is called by the ClientConn when it is done reading from a request body.
   373  func (b *testRequestBody) Close() error {
   374  	return nil
   375  }
   376  
   377  // writeBytes adds n arbitrary bytes to the body.
   378  func (b *testRequestBody) writeBytes(n int) {
   379  	b.bytes += n
   380  	b.checkWrite()
   381  	b.tc.sync()
   382  }
   383  
   384  // Write adds bytes to the body.
   385  func (b *testRequestBody) Write(p []byte) (int, error) {
   386  	n, err := b.buf.Write(p)
   387  	b.checkWrite()
   388  	b.tc.sync()
   389  	return n, err
   390  }
   391  
   392  func (b *testRequestBody) checkWrite() {
   393  	if b.bytes > 0 && b.buf.Len() > 0 {
   394  		b.tc.t.Fatalf("can't interleave Write and writeBytes on request body")
   395  	}
   396  	if b.err != nil {
   397  		b.tc.t.Fatalf("can't write to request body after closeWithError")
   398  	}
   399  }
   400  
   401  // closeWithError sets an error which will be returned by Read.
   402  func (b *testRequestBody) closeWithError(err error) {
   403  	b.err = err
   404  	b.tc.sync()
   405  }
   406  
   407  // roundTrip starts a RoundTrip call.
   408  //
   409  // (Note that the RoundTrip won't complete until response headers are received,
   410  // the request times out, or some other terminal condition is reached.)
   411  func (tc *testClientConn) roundTrip(req *http.Request) *testRoundTrip {
   412  	rt := &testRoundTrip{
   413  		t:     tc.t,
   414  		donec: make(chan struct{}),
   415  	}
   416  	tc.roundtrips = append(tc.roundtrips, rt)
   417  	tc.hooks.newstream = func(cs *clientStream) { rt.cs = cs }
   418  	tc.cc.goRun(func() {
   419  		defer close(rt.donec)
   420  		rt.resp, rt.respErr = tc.cc.RoundTrip(req)
   421  	})
   422  	tc.sync()
   423  	tc.hooks.newstream = nil
   424  
   425  	tc.t.Cleanup(func() {
   426  		if !rt.done() {
   427  			return
   428  		}
   429  		res, _ := rt.result()
   430  		if res != nil {
   431  			res.Body.Close()
   432  		}
   433  	})
   434  
   435  	return rt
   436  }
   437  
   438  func (tc *testClientConn) greet(settings ...Setting) {
   439  	tc.wantFrameType(FrameSettings)
   440  	tc.wantFrameType(FrameWindowUpdate)
   441  	tc.writeSettings(settings...)
   442  	tc.writeSettingsAck()
   443  	tc.wantFrameType(FrameSettings) // acknowledgement
   444  }
   445  
   446  func (tc *testClientConn) writeSettings(settings ...Setting) {
   447  	tc.t.Helper()
   448  	if err := tc.fr.WriteSettings(settings...); err != nil {
   449  		tc.t.Fatal(err)
   450  	}
   451  	tc.sync()
   452  }
   453  
   454  func (tc *testClientConn) writeSettingsAck() {
   455  	tc.t.Helper()
   456  	if err := tc.fr.WriteSettingsAck(); err != nil {
   457  		tc.t.Fatal(err)
   458  	}
   459  	tc.sync()
   460  }
   461  
   462  func (tc *testClientConn) writeData(streamID uint32, endStream bool, data []byte) {
   463  	tc.t.Helper()
   464  	if err := tc.fr.WriteData(streamID, endStream, data); err != nil {
   465  		tc.t.Fatal(err)
   466  	}
   467  	tc.sync()
   468  }
   469  
   470  func (tc *testClientConn) writeDataPadded(streamID uint32, endStream bool, data, pad []byte) {
   471  	tc.t.Helper()
   472  	if err := tc.fr.WriteDataPadded(streamID, endStream, data, pad); err != nil {
   473  		tc.t.Fatal(err)
   474  	}
   475  	tc.sync()
   476  }
   477  
   478  // makeHeaderBlockFragment encodes headers in a form suitable for inclusion
   479  // in a HEADERS or CONTINUATION frame.
   480  //
   481  // It takes a list of alernating names and values.
   482  func (tc *testClientConn) makeHeaderBlockFragment(s ...string) []byte {
   483  	if len(s)%2 != 0 {
   484  		tc.t.Fatalf("uneven list of header name/value pairs")
   485  	}
   486  	tc.encbuf.Reset()
   487  	for i := 0; i < len(s); i += 2 {
   488  		tc.enc.WriteField(hpack.HeaderField{Name: s[i], Value: s[i+1]})
   489  	}
   490  	return tc.encbuf.Bytes()
   491  }
   492  
   493  func (tc *testClientConn) writeHeaders(p HeadersFrameParam) {
   494  	tc.t.Helper()
   495  	if err := tc.fr.WriteHeaders(p); err != nil {
   496  		tc.t.Fatal(err)
   497  	}
   498  	tc.sync()
   499  }
   500  
   501  // writeHeadersMode writes header frames, as modified by mode:
   502  //
   503  //   - noHeader: Don't write the header.
   504  //   - oneHeader: Write a single HEADERS frame.
   505  //   - splitHeader: Write a HEADERS frame and CONTINUATION frame.
   506  func (tc *testClientConn) writeHeadersMode(mode headerType, p HeadersFrameParam) {
   507  	tc.t.Helper()
   508  	switch mode {
   509  	case noHeader:
   510  	case oneHeader:
   511  		tc.writeHeaders(p)
   512  	case splitHeader:
   513  		if len(p.BlockFragment) < 2 {
   514  			panic("too small")
   515  		}
   516  		contData := p.BlockFragment[1:]
   517  		contEnd := p.EndHeaders
   518  		p.BlockFragment = p.BlockFragment[:1]
   519  		p.EndHeaders = false
   520  		tc.writeHeaders(p)
   521  		tc.writeContinuation(p.StreamID, contEnd, contData)
   522  	default:
   523  		panic("bogus mode")
   524  	}
   525  }
   526  
   527  func (tc *testClientConn) writeContinuation(streamID uint32, endHeaders bool, headerBlockFragment []byte) {
   528  	tc.t.Helper()
   529  	if err := tc.fr.WriteContinuation(streamID, endHeaders, headerBlockFragment); err != nil {
   530  		tc.t.Fatal(err)
   531  	}
   532  	tc.sync()
   533  }
   534  
   535  func (tc *testClientConn) writeRSTStream(streamID uint32, code ErrCode) {
   536  	tc.t.Helper()
   537  	if err := tc.fr.WriteRSTStream(streamID, code); err != nil {
   538  		tc.t.Fatal(err)
   539  	}
   540  	tc.sync()
   541  }
   542  
   543  func (tc *testClientConn) writePing(ack bool, data [8]byte) {
   544  	tc.t.Helper()
   545  	if err := tc.fr.WritePing(ack, data); err != nil {
   546  		tc.t.Fatal(err)
   547  	}
   548  	tc.sync()
   549  }
   550  
   551  func (tc *testClientConn) writeGoAway(maxStreamID uint32, code ErrCode, debugData []byte) {
   552  	tc.t.Helper()
   553  	if err := tc.fr.WriteGoAway(maxStreamID, code, debugData); err != nil {
   554  		tc.t.Fatal(err)
   555  	}
   556  	tc.sync()
   557  }
   558  
   559  func (tc *testClientConn) writeWindowUpdate(streamID, incr uint32) {
   560  	tc.t.Helper()
   561  	if err := tc.fr.WriteWindowUpdate(streamID, incr); err != nil {
   562  		tc.t.Fatal(err)
   563  	}
   564  	tc.sync()
   565  }
   566  
   567  // closeWrite causes the net.Conn used by the ClientConn to return a error
   568  // from Read calls.
   569  func (tc *testClientConn) closeWrite(err error) {
   570  	tc.rerr = err
   571  	tc.sync()
   572  }
   573  
   574  // inflowWindow returns the amount of inbound flow control available for a stream,
   575  // or for the connection if streamID is 0.
   576  func (tc *testClientConn) inflowWindow(streamID uint32) int32 {
   577  	tc.cc.mu.Lock()
   578  	defer tc.cc.mu.Unlock()
   579  	if streamID == 0 {
   580  		return tc.cc.inflow.avail + tc.cc.inflow.unsent
   581  	}
   582  	cs := tc.cc.streams[streamID]
   583  	if cs == nil {
   584  		tc.t.Errorf("no stream with id %v", streamID)
   585  		return -1
   586  	}
   587  	return cs.inflow.avail + cs.inflow.unsent
   588  }
   589  
   590  // testRoundTrip manages a RoundTrip in progress.
   591  type testRoundTrip struct {
   592  	t       *testing.T
   593  	resp    *http.Response
   594  	respErr error
   595  	donec   chan struct{}
   596  	cs      *clientStream
   597  }
   598  
   599  // streamID returns the HTTP/2 stream ID of the request.
   600  func (rt *testRoundTrip) streamID() uint32 {
   601  	if rt.cs == nil {
   602  		panic("stream ID unknown")
   603  	}
   604  	return rt.cs.ID
   605  }
   606  
   607  // done reports whether RoundTrip has returned.
   608  func (rt *testRoundTrip) done() bool {
   609  	select {
   610  	case <-rt.donec:
   611  		return true
   612  	default:
   613  		return false
   614  	}
   615  }
   616  
   617  // result returns the result of the RoundTrip.
   618  func (rt *testRoundTrip) result() (*http.Response, error) {
   619  	t := rt.t
   620  	t.Helper()
   621  	select {
   622  	case <-rt.donec:
   623  	default:
   624  		t.Fatalf("RoundTrip is not done; want it to be")
   625  	}
   626  	return rt.resp, rt.respErr
   627  }
   628  
   629  // response returns the response of a successful RoundTrip.
   630  // If the RoundTrip unexpectedly failed, it calls t.Fatal.
   631  func (rt *testRoundTrip) response() *http.Response {
   632  	t := rt.t
   633  	t.Helper()
   634  	resp, err := rt.result()
   635  	if err != nil {
   636  		t.Fatalf("RoundTrip returned unexpected error: %v", rt.respErr)
   637  	}
   638  	if resp == nil {
   639  		t.Fatalf("RoundTrip returned nil *Response and nil error")
   640  	}
   641  	return resp
   642  }
   643  
   644  // err returns the (possibly nil) error result of RoundTrip.
   645  func (rt *testRoundTrip) err() error {
   646  	t := rt.t
   647  	t.Helper()
   648  	_, err := rt.result()
   649  	return err
   650  }
   651  
   652  // wantStatus indicates the expected response StatusCode.
   653  func (rt *testRoundTrip) wantStatus(want int) {
   654  	t := rt.t
   655  	t.Helper()
   656  	if got := rt.response().StatusCode; got != want {
   657  		t.Fatalf("got response status %v, want %v", got, want)
   658  	}
   659  }
   660  
   661  // body reads the contents of the response body.
   662  func (rt *testRoundTrip) readBody() ([]byte, error) {
   663  	t := rt.t
   664  	t.Helper()
   665  	return io.ReadAll(rt.response().Body)
   666  }
   667  
   668  // wantBody indicates the expected response body.
   669  // (Note that this consumes the body.)
   670  func (rt *testRoundTrip) wantBody(want []byte) {
   671  	t := rt.t
   672  	t.Helper()
   673  	got, err := rt.readBody()
   674  	if err != nil {
   675  		t.Fatalf("unexpected error reading response body: %v", err)
   676  	}
   677  	if !bytes.Equal(got, want) {
   678  		t.Fatalf("unexpected response body:\ngot:  %q\nwant: %q", got, want)
   679  	}
   680  }
   681  
   682  // wantHeaders indicates the expected response headers.
   683  func (rt *testRoundTrip) wantHeaders(want http.Header) {
   684  	t := rt.t
   685  	t.Helper()
   686  	res := rt.response()
   687  	if diff := diffHeaders(res.Header, want); diff != "" {
   688  		t.Fatalf("unexpected response headers:\n%v", diff)
   689  	}
   690  }
   691  
   692  // wantTrailers indicates the expected response trailers.
   693  func (rt *testRoundTrip) wantTrailers(want http.Header) {
   694  	t := rt.t
   695  	t.Helper()
   696  	res := rt.response()
   697  	if diff := diffHeaders(res.Trailer, want); diff != "" {
   698  		t.Fatalf("unexpected response trailers:\n%v", diff)
   699  	}
   700  }
   701  
   702  func diffHeaders(got, want http.Header) string {
   703  	// nil and 0-length non-nil are equal.
   704  	if len(got) == 0 && len(want) == 0 {
   705  		return ""
   706  	}
   707  	// We could do a more sophisticated diff here.
   708  	// DeepEqual is good enough for now.
   709  	if reflect.DeepEqual(got, want) {
   710  		return ""
   711  	}
   712  	return fmt.Sprintf("got:  %v\nwant: %v", got, want)
   713  }
   714  
   715  // testClientConnNetConn implements net.Conn.
   716  type testClientConnNetConn testClientConn
   717  
   718  func (nc *testClientConnNetConn) Read(b []byte) (n int, err error) {
   719  	nc.cc.syncHooks.blockUntil(func() bool {
   720  		return nc.rerr != nil || nc.rbuf.Len() > 0
   721  	})
   722  	if nc.rbuf.Len() > 0 {
   723  		return nc.rbuf.Read(b)
   724  	}
   725  	return 0, nc.rerr
   726  }
   727  
   728  func (nc *testClientConnNetConn) Write(b []byte) (n int, err error) {
   729  	return nc.wbuf.Write(b)
   730  }
   731  
   732  func (nc *testClientConnNetConn) Close() error {
   733  	nc.netConnClosed = true
   734  	return nil
   735  }
   736  
   737  func (*testClientConnNetConn) LocalAddr() (_ net.Addr)            { return }
   738  func (*testClientConnNetConn) RemoteAddr() (_ net.Addr)           { return }
   739  func (*testClientConnNetConn) SetDeadline(t time.Time) error      { return nil }
   740  func (*testClientConnNetConn) SetReadDeadline(t time.Time) error  { return nil }
   741  func (*testClientConnNetConn) SetWriteDeadline(t time.Time) error { return nil }
   742  
   743  // A testTransport allows testing Transport.RoundTrip against fake servers.
   744  // Tests that aren't specifically exercising RoundTrip's retry loop or connection pooling
   745  // should use testClientConn instead.
   746  type testTransport struct {
   747  	t  *testing.T
   748  	tr *Transport
   749  
   750  	ccs []*testClientConn
   751  }
   752  
   753  func newTestTransport(t *testing.T, opts ...func(*Transport)) *testTransport {
   754  	tr := &Transport{
   755  		syncHooks: newTestSyncHooks(),
   756  	}
   757  	for _, o := range opts {
   758  		o(tr)
   759  	}
   760  
   761  	tt := &testTransport{
   762  		t:  t,
   763  		tr: tr,
   764  	}
   765  	tr.syncHooks.newclientconn = func(cc *ClientConn) {
   766  		tt.ccs = append(tt.ccs, newTestClientConnFromClientConn(t, cc))
   767  	}
   768  
   769  	t.Cleanup(func() {
   770  		tt.sync()
   771  		if len(tt.ccs) > 0 {
   772  			t.Fatalf("%v test ClientConns created, but not examined by test", len(tt.ccs))
   773  		}
   774  		if tt.tr.syncHooks.total != 0 {
   775  			t.Errorf("%v goroutines still running after test completed", tt.tr.syncHooks.total)
   776  		}
   777  	})
   778  
   779  	return tt
   780  }
   781  
   782  func (tt *testTransport) sync() {
   783  	tt.tr.syncHooks.waitInactive()
   784  }
   785  
   786  func (tt *testTransport) advance(d time.Duration) {
   787  	tt.tr.syncHooks.advance(d)
   788  	tt.sync()
   789  }
   790  
   791  func (tt *testTransport) hasConn() bool {
   792  	return len(tt.ccs) > 0
   793  }
   794  
   795  func (tt *testTransport) getConn() *testClientConn {
   796  	tt.t.Helper()
   797  	if len(tt.ccs) == 0 {
   798  		tt.t.Fatalf("no new ClientConns created; wanted one")
   799  	}
   800  	tc := tt.ccs[0]
   801  	tt.ccs = tt.ccs[1:]
   802  	tc.sync()
   803  	tc.readClientPreface()
   804  	return tc
   805  }
   806  
   807  func (tt *testTransport) roundTrip(req *http.Request) *testRoundTrip {
   808  	rt := &testRoundTrip{
   809  		t:     tt.t,
   810  		donec: make(chan struct{}),
   811  	}
   812  	tt.tr.syncHooks.goRun(func() {
   813  		defer close(rt.donec)
   814  		rt.resp, rt.respErr = tt.tr.RoundTrip(req)
   815  	})
   816  	tt.sync()
   817  
   818  	tt.t.Cleanup(func() {
   819  		if !rt.done() {
   820  			return
   821  		}
   822  		res, _ := rt.result()
   823  		if res != nil {
   824  			res.Body.Close()
   825  		}
   826  	})
   827  
   828  	return rt
   829  }