golang.org/x/net@v0.25.1-0.20240516223405-c87a5b62e243/quic/conn_test.go (about)

     1  // Copyright 2023 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  //go:build go1.21
     6  
     7  package quic
     8  
     9  import (
    10  	"bytes"
    11  	"context"
    12  	"crypto/tls"
    13  	"errors"
    14  	"flag"
    15  	"fmt"
    16  	"log/slog"
    17  	"math"
    18  	"net/netip"
    19  	"reflect"
    20  	"strings"
    21  	"testing"
    22  	"time"
    23  
    24  	"golang.org/x/net/quic/qlog"
    25  )
    26  
// Command-line flags for this package's tests, registered on the default
// flag set by flag.Bool and flag.String.
var (
	// testVV enables logDatagram's per-datagram, per-frame logging.
	testVV  = flag.Bool("vv", false, "even more verbose test output")
	// qlogdir is used as the Dir of the qlog JSON handler that newTestConn
	// installs on each test Conn's config.
	qlogdir = flag.String("qlog", "", "write qlog logs to directory")
)
    31  
    32  func TestConnTestConn(t *testing.T) {
    33  	tc := newTestConn(t, serverSide)
    34  	tc.handshake()
    35  	if got, want := tc.timeUntilEvent(), defaultMaxIdleTimeout; got != want {
    36  		t.Errorf("new conn timeout=%v, want %v (max_idle_timeout)", got, want)
    37  	}
    38  
    39  	ranAt, _ := runAsync(tc, func(ctx context.Context) (when time.Time, _ error) {
    40  		tc.conn.runOnLoop(ctx, func(now time.Time, c *Conn) {
    41  			when = now
    42  		})
    43  		return
    44  	}).result()
    45  	if !ranAt.Equal(tc.endpoint.now) {
    46  		t.Errorf("func ran on loop at %v, want %v", ranAt, tc.endpoint.now)
    47  	}
    48  	tc.wait()
    49  
    50  	nextTime := tc.endpoint.now.Add(defaultMaxIdleTimeout / 2)
    51  	tc.advanceTo(nextTime)
    52  	ranAt, _ = runAsync(tc, func(ctx context.Context) (when time.Time, _ error) {
    53  		tc.conn.runOnLoop(ctx, func(now time.Time, c *Conn) {
    54  			when = now
    55  		})
    56  		return
    57  	}).result()
    58  	if !ranAt.Equal(nextTime) {
    59  		t.Errorf("func ran on loop at %v, want %v", ranAt, nextTime)
    60  	}
    61  	tc.wait()
    62  
    63  	tc.advanceToTimer()
    64  	if got := tc.conn.lifetime.state; got != connStateDone {
    65  		t.Errorf("after advancing to idle timeout, conn state = %v, want done", got)
    66  	}
    67  }
    68  
// testDatagram is the decoded form of a UDP datagram exchanged with the
// conn under test.
type testDatagram struct {
	// packets are the (possibly coalesced) packets in the datagram, in order.
	packets    []*testPacket
	// paddedSize is the size the datagram is padded to, or 0 if unpadded.
	paddedSize int
	// addr is the remote address associated with the datagram
	// (writeFrames sets it to tc.conn.peerAddr).
	addr       netip.AddrPort
}
    74  
    75  func (d testDatagram) String() string {
    76  	var b strings.Builder
    77  	fmt.Fprintf(&b, "datagram with %v packets", len(d.packets))
    78  	if d.paddedSize > 0 {
    79  		fmt.Fprintf(&b, " (padded to %v bytes)", d.paddedSize)
    80  	}
    81  	b.WriteString(":")
    82  	for _, p := range d.packets {
    83  		b.WriteString("\n")
    84  		b.WriteString(p.String())
    85  	}
    86  	return b.String()
    87  }
    88  
// testPacket is the decoded form of a single QUIC packet exchanged with the
// conn under test.
type testPacket struct {
	ptype             packetType
	// header is the raw first byte of a parsed packet; packetEqual ignores it.
	header            byte
	version           uint32
	num               packetNumber
	// keyPhaseBit is the 1-RTT key phase bit.
	keyPhaseBit       bool
	// keyNumber selects which of the test's 1-RTT packet keys protects the
	// packet (an index into test1RTTKeys.pkt).
	keyNumber         int
	dstConnID         []byte
	srcConnID         []byte
	// token is the Initial or Retry packet token, if any.
	token             []byte
	originalDstConnID []byte // used for encoding Retry packets
	frames            []debugFrame
}
   102  
   103  func (p testPacket) String() string {
   104  	var b strings.Builder
   105  	fmt.Fprintf(&b, "  %v %v", p.ptype, p.num)
   106  	if p.version != 0 {
   107  		fmt.Fprintf(&b, " version=%v", p.version)
   108  	}
   109  	if p.srcConnID != nil {
   110  		fmt.Fprintf(&b, " src={%x}", p.srcConnID)
   111  	}
   112  	if p.dstConnID != nil {
   113  		fmt.Fprintf(&b, " dst={%x}", p.dstConnID)
   114  	}
   115  	if p.token != nil {
   116  		fmt.Fprintf(&b, " token={%x}", p.token)
   117  	}
   118  	for _, f := range p.frames {
   119  		fmt.Fprintf(&b, "\n    %v", f)
   120  	}
   121  	return b.String()
   122  }
   123  
// maxTestKeyPhases is the maximum number of 1-RTT keys we'll generate in a test.
// It sizes test1RTTKeys.pkt. It is also the number of key phases that
// parseTestDatagram tries when decrypting a 1-RTT packet sent by the conn.
const maxTestKeyPhases = 3
   126  
// A testConn is a Conn whose external interactions (sending and receiving packets,
// setting timers) can be manipulated in tests.
type testConn struct {
	t              *testing.T
	conn           *Conn
	endpoint       *testEndpoint
	// timer is when the conn's next timer event is due; the zero time means
	// no timer is set (see advanceToTimer and timeUntilEvent).
	timer          time.Time
	// NOTE(review): timerLastFired is not read in this part of the file;
	// presumably it records when the timer last fired — confirm.
	timerLastFired time.Time
	// idlec is installed by wait and closed once the conn's loop is idle.
	idlec          chan struct{} // only accessed on the conn's loop

	// Keys are distinct from the conn's keys,
	// because the test may know about keys before the conn does.
	// For example, when sending a datagram with coalesced
	// Initial and Handshake packets to a client conn,
	// we use Handshake keys to encrypt the packet.
	// The client only acquires those keys when it processes
	// the Initial packet.
	keysInitial   fixedKeyPair
	keysHandshake fixedKeyPair
	rkeyAppData   test1RTTKeys
	wkeyAppData   test1RTTKeys
	rsecrets      [numberSpaceCount]keySecret
	wsecrets      [numberSpaceCount]keySecret

	// testConn uses a test hook to snoop on the conn's TLS events.
	// CRYPTO data produced by the conn's QUICConn is placed in
	// cryptoDataOut.
	//
	// The peerTLSConn is a QUICConn representing the peer.
	// CRYPTO data produced by the conn is written to peerTLSConn,
	// and data produced by peerTLSConn is placed in cryptoDataIn.
	cryptoDataOut map[tls.QUICEncryptionLevel][]byte
	cryptoDataIn  map[tls.QUICEncryptionLevel][]byte
	peerTLSConn   *tls.QUICConn

	// Information about the conn's (fake) peer.
	peerConnID        []byte                         // source conn id of peer's packets
	peerNextPacketNum [numberSpaceCount]packetNumber // next packet number to use

	// Datagrams, packets, and frames sent by the conn,
	// but not yet processed by the test.
	// lastDatagram and lastPacket are the datagram and packet most recently
	// returned by readDatagram and readPacket. readPacket also records here
	// packets it skips because they have no frames.
	sentDatagrams [][]byte
	sentPackets   []*testPacket
	sentFrames    []debugFrame
	lastDatagram  *testDatagram
	lastPacket    *testPacket

	recvDatagram chan *datagram

	// Transport parameters sent by the conn.
	sentTransportParameters *transportParameters

	// Frame types to ignore in tests.
	// readDatagram drops frames whose type is marked true here.
	ignoreFrames map[byte]bool

	// Values to set in packets sent to the conn.
	sendKeyNumber   int
	sendKeyPhaseBit bool

	asyncTestState
}
   188  
// test1RTTKeys holds the test's 1-RTT keys for one direction. testConn keeps
// one set for reading (rkeyAppData) and one for writing (wkeyAppData).
type test1RTTKeys struct {
	// hdr is the single header-protection key.
	hdr headerKey
	// pkt holds one packet-protection key per key phase, indexed by
	// testPacket.keyNumber.
	pkt [maxTestKeyPhases]packetKey
}
   193  
// keySecret pairs a key secret with its suite; testConn keeps one per number
// space for each direction (rsecrets, wsecrets).
// NOTE(review): presumably suite is a TLS cipher suite ID and secret a traffic
// secret captured from the conn's TLS events — confirm against the test hooks.
type keySecret struct {
	suite  uint16
	secret []byte
}
   198  
   199  // newTestConn creates a Conn for testing.
   200  //
   201  // The Conn's event loop is controlled by the test,
   202  // allowing test code to access Conn state directly
   203  // by first ensuring the loop goroutine is idle.
   204  func newTestConn(t *testing.T, side connSide, opts ...any) *testConn {
   205  	t.Helper()
   206  	config := &Config{
   207  		TLSConfig:         newTestTLSConfig(side),
   208  		StatelessResetKey: testStatelessResetKey,
   209  		QLogLogger: slog.New(qlog.NewJSONHandler(qlog.HandlerOptions{
   210  			Level: QLogLevelFrame,
   211  			Dir:   *qlogdir,
   212  		})),
   213  	}
   214  	var cids newServerConnIDs
   215  	if side == serverSide {
   216  		// The initial connection ID for the server is chosen by the client.
   217  		cids.srcConnID = testPeerConnID(0)
   218  		cids.dstConnID = testPeerConnID(-1)
   219  		cids.originalDstConnID = cids.dstConnID
   220  	}
   221  	var configTransportParams []func(*transportParameters)
   222  	var configTestConn []func(*testConn)
   223  	for _, o := range opts {
   224  		switch o := o.(type) {
   225  		case func(*Config):
   226  			o(config)
   227  		case func(*tls.Config):
   228  			o(config.TLSConfig)
   229  		case func(cids *newServerConnIDs):
   230  			o(&cids)
   231  		case func(p *transportParameters):
   232  			configTransportParams = append(configTransportParams, o)
   233  		case func(p *testConn):
   234  			configTestConn = append(configTestConn, o)
   235  		default:
   236  			t.Fatalf("unknown newTestConn option %T", o)
   237  		}
   238  	}
   239  
   240  	endpoint := newTestEndpoint(t, config)
   241  	endpoint.configTransportParams = configTransportParams
   242  	endpoint.configTestConn = configTestConn
   243  	conn, err := endpoint.e.newConn(
   244  		endpoint.now,
   245  		config,
   246  		side,
   247  		cids,
   248  		"",
   249  		netip.MustParseAddrPort("127.0.0.1:443"))
   250  	if err != nil {
   251  		t.Fatal(err)
   252  	}
   253  	tc := endpoint.conns[conn]
   254  	tc.wait()
   255  	return tc
   256  }
   257  
   258  func newTestConnForConn(t *testing.T, endpoint *testEndpoint, conn *Conn) *testConn {
   259  	t.Helper()
   260  	tc := &testConn{
   261  		t:          t,
   262  		endpoint:   endpoint,
   263  		conn:       conn,
   264  		peerConnID: testPeerConnID(0),
   265  		ignoreFrames: map[byte]bool{
   266  			frameTypePadding: true, // ignore PADDING by default
   267  		},
   268  		cryptoDataOut: make(map[tls.QUICEncryptionLevel][]byte),
   269  		cryptoDataIn:  make(map[tls.QUICEncryptionLevel][]byte),
   270  		recvDatagram:  make(chan *datagram),
   271  	}
   272  	t.Cleanup(tc.cleanup)
   273  	for _, f := range endpoint.configTestConn {
   274  		f(tc)
   275  	}
   276  	conn.testHooks = (*testConnHooks)(tc)
   277  
   278  	if endpoint.peerTLSConn != nil {
   279  		tc.peerTLSConn = endpoint.peerTLSConn
   280  		endpoint.peerTLSConn = nil
   281  		return tc
   282  	}
   283  
   284  	peerProvidedParams := defaultTransportParameters()
   285  	peerProvidedParams.initialSrcConnID = testPeerConnID(0)
   286  	if conn.side == clientSide {
   287  		peerProvidedParams.originalDstConnID = testLocalConnID(-1)
   288  	}
   289  	for _, f := range endpoint.configTransportParams {
   290  		f(&peerProvidedParams)
   291  	}
   292  
   293  	peerQUICConfig := &tls.QUICConfig{TLSConfig: newTestTLSConfig(conn.side.peer())}
   294  	if conn.side == clientSide {
   295  		tc.peerTLSConn = tls.QUICServer(peerQUICConfig)
   296  	} else {
   297  		tc.peerTLSConn = tls.QUICClient(peerQUICConfig)
   298  	}
   299  	tc.peerTLSConn.SetTransportParameters(marshalTransportParameters(peerProvidedParams))
   300  	tc.peerTLSConn.Start(context.Background())
   301  	t.Cleanup(func() {
   302  		tc.peerTLSConn.Close()
   303  	})
   304  
   305  	return tc
   306  }
   307  
   308  // advance causes time to pass.
   309  func (tc *testConn) advance(d time.Duration) {
   310  	tc.t.Helper()
   311  	tc.endpoint.advance(d)
   312  }
   313  
   314  // advanceTo sets the current time.
   315  func (tc *testConn) advanceTo(now time.Time) {
   316  	tc.t.Helper()
   317  	tc.endpoint.advanceTo(now)
   318  }
   319  
   320  // advanceToTimer sets the current time to the time of the Conn's next timer event.
   321  func (tc *testConn) advanceToTimer() {
   322  	if tc.timer.IsZero() {
   323  		tc.t.Fatalf("advancing to timer, but timer is not set")
   324  	}
   325  	tc.advanceTo(tc.timer)
   326  }
   327  
   328  func (tc *testConn) timerDelay() time.Duration {
   329  	if tc.timer.IsZero() {
   330  		return math.MaxInt64 // infinite
   331  	}
   332  	if tc.timer.Before(tc.endpoint.now) {
   333  		return 0
   334  	}
   335  	return tc.timer.Sub(tc.endpoint.now)
   336  }
   337  
   338  const infiniteDuration = time.Duration(math.MaxInt64)
   339  
   340  // timeUntilEvent returns the amount of time until the next connection event.
   341  func (tc *testConn) timeUntilEvent() time.Duration {
   342  	if tc.timer.IsZero() {
   343  		return infiniteDuration
   344  	}
   345  	if tc.timer.Before(tc.endpoint.now) {
   346  		return 0
   347  	}
   348  	return tc.timer.Sub(tc.endpoint.now)
   349  }
   350  
// wait blocks until the conn becomes idle.
// The conn is idle when it is blocked waiting for a packet to arrive or a timer to expire.
// Tests shouldn't need to call wait directly.
// testConn methods that wake the Conn event loop will call wait for them.
//
// It sends the conn's loop a message that installs a fresh idlec channel,
// which the loop closes once idle. If another wait is already pending
// (tc.idlec is non-nil), the call is reported as concurrent misuse and wait
// panics. If the conn exits instead, wait returns once donec is closed, after
// waking any async operations that can now make progress.
func (tc *testConn) wait() {
	tc.t.Helper()
	idlec := make(chan struct{})
	// fail is set by the message below when wait was called concurrently.
	fail := false
	tc.conn.sendMsg(func(now time.Time, c *Conn) {
		if tc.idlec != nil {
			tc.t.Errorf("testConn.wait called concurrently")
			fail = true
			close(idlec)
		} else {
			// nextMessage will close idlec.
			tc.idlec = idlec
		}
	})
	select {
	case <-idlec:
	case <-tc.conn.donec:
		// We may have async ops that can proceed now that the conn is done.
		tc.wakeAsync()
	}
	if fail {
		panic(fail)
	}
}
   379  
   380  func (tc *testConn) cleanup() {
   381  	if tc.conn == nil {
   382  		return
   383  	}
   384  	tc.conn.exit()
   385  	<-tc.conn.donec
   386  }
   387  
   388  func (tc *testConn) acceptStream() *Stream {
   389  	tc.t.Helper()
   390  	s, err := tc.conn.AcceptStream(canceledContext())
   391  	if err != nil {
   392  		tc.t.Fatalf("conn.AcceptStream() = %v, want stream", err)
   393  	}
   394  	s.SetReadContext(canceledContext())
   395  	s.SetWriteContext(canceledContext())
   396  	return s
   397  }
   398  
   399  func logDatagram(t *testing.T, text string, d *testDatagram) {
   400  	t.Helper()
   401  	if !*testVV {
   402  		return
   403  	}
   404  	pad := ""
   405  	if d.paddedSize > 0 {
   406  		pad = fmt.Sprintf(" (padded to %v)", d.paddedSize)
   407  	}
   408  	t.Logf("%v datagram%v", text, pad)
   409  	for _, p := range d.packets {
   410  		var s string
   411  		switch p.ptype {
   412  		case packetType1RTT:
   413  			s = fmt.Sprintf("  %v pnum=%v", p.ptype, p.num)
   414  		default:
   415  			s = fmt.Sprintf("  %v pnum=%v ver=%v dst={%x} src={%x}", p.ptype, p.num, p.version, p.dstConnID, p.srcConnID)
   416  		}
   417  		if p.token != nil {
   418  			s += fmt.Sprintf(" token={%x}", p.token)
   419  		}
   420  		if p.keyPhaseBit {
   421  			s += fmt.Sprintf(" KeyPhase")
   422  		}
   423  		if p.keyNumber != 0 {
   424  			s += fmt.Sprintf(" keynum=%v", p.keyNumber)
   425  		}
   426  		t.Log(s)
   427  		for _, f := range p.frames {
   428  			t.Logf("    %v", f)
   429  		}
   430  	}
   431  }
   432  
   433  // write sends the Conn a datagram.
   434  func (tc *testConn) write(d *testDatagram) {
   435  	tc.t.Helper()
   436  	tc.endpoint.writeDatagram(d)
   437  }
   438  
   439  // writeFrame sends the Conn a datagram containing the given frames.
   440  func (tc *testConn) writeFrames(ptype packetType, frames ...debugFrame) {
   441  	tc.t.Helper()
   442  	space := spaceForPacketType(ptype)
   443  	dstConnID := tc.conn.connIDState.local[0].cid
   444  	if tc.conn.connIDState.local[0].seq == -1 && ptype != packetTypeInitial {
   445  		// Only use the transient connection ID in Initial packets.
   446  		dstConnID = tc.conn.connIDState.local[1].cid
   447  	}
   448  	d := &testDatagram{
   449  		packets: []*testPacket{{
   450  			ptype:       ptype,
   451  			num:         tc.peerNextPacketNum[space],
   452  			keyNumber:   tc.sendKeyNumber,
   453  			keyPhaseBit: tc.sendKeyPhaseBit,
   454  			frames:      frames,
   455  			version:     quicVersion1,
   456  			dstConnID:   dstConnID,
   457  			srcConnID:   tc.peerConnID,
   458  		}},
   459  		addr: tc.conn.peerAddr,
   460  	}
   461  	if ptype == packetTypeInitial && tc.conn.side == serverSide {
   462  		d.paddedSize = 1200
   463  	}
   464  	tc.write(d)
   465  }
   466  
   467  // writeAckForAll sends the Conn a datagram containing an ack for all packets up to the
   468  // last one received.
   469  func (tc *testConn) writeAckForAll() {
   470  	tc.t.Helper()
   471  	if tc.lastPacket == nil {
   472  		return
   473  	}
   474  	tc.writeFrames(tc.lastPacket.ptype, debugFrameAck{
   475  		ranges: []i64range[packetNumber]{{0, tc.lastPacket.num + 1}},
   476  	})
   477  }
   478  
   479  // writeAckForLatest sends the Conn a datagram containing an ack for the
   480  // most recent packet received.
   481  func (tc *testConn) writeAckForLatest() {
   482  	tc.t.Helper()
   483  	if tc.lastPacket == nil {
   484  		return
   485  	}
   486  	tc.writeFrames(tc.lastPacket.ptype, debugFrameAck{
   487  		ranges: []i64range[packetNumber]{{tc.lastPacket.num, tc.lastPacket.num + 1}},
   488  	})
   489  }
   490  
   491  // ignoreFrame hides frames of the given type sent by the Conn.
   492  func (tc *testConn) ignoreFrame(frameType byte) {
   493  	tc.ignoreFrames[frameType] = true
   494  }
   495  
// readDatagram reads the next datagram sent by the Conn.
// It returns nil if the Conn has no more datagrams to send at this time.
//
// It first waits for the Conn's loop to go idle. It then discards any packets
// and frames still buffered from earlier readPacket/readFrame calls. The
// datagram is logged in full (when -vv is set). After that, frames whose type
// is marked in tc.ignoreFrames are removed from its packets. The filtered
// datagram is recorded in tc.lastDatagram and returned.
func (tc *testConn) readDatagram() *testDatagram {
	tc.t.Helper()
	tc.wait()
	// Anything not yet consumed by readPacket/readFrame is dropped.
	tc.sentPackets = nil
	tc.sentFrames = nil
	buf := tc.endpoint.read()
	if buf == nil {
		return nil
	}
	d := parseTestDatagram(tc.t, tc.endpoint, tc, buf)
	// Log the datagram before removing ignored frames.
	// When things go wrong, it's useful to see all the frames.
	logDatagram(tc.t, "-> conn under test sends", d)
	// typeForFrame maps a parsed frame back to the frame type byte used as a
	// key in tc.ignoreFrames. Every debugFrameStream maps to
	// frameTypeStreamBase. MAX_STREAMS and STREAMS_BLOCKED frames map to their
	// bidi or uni type according to streamType. An unrecognized debugFrame
	// type panics.
	typeForFrame := func(f debugFrame) byte {
		// This is very clunky, and points at a problem
		// in how we specify what frames to ignore in tests.
		//
		// We mark frames to ignore using the frame type,
		// but we've got a debugFrame data structure here.
		// Perhaps we should be ignoring frames by debugFrame
		// type instead: tc.ignoreFrame[debugFrameAck]().
		switch f := f.(type) {
		case debugFramePadding:
			return frameTypePadding
		case debugFramePing:
			return frameTypePing
		case debugFrameAck:
			return frameTypeAck
		case debugFrameResetStream:
			return frameTypeResetStream
		case debugFrameStopSending:
			return frameTypeStopSending
		case debugFrameCrypto:
			return frameTypeCrypto
		case debugFrameNewToken:
			return frameTypeNewToken
		case debugFrameStream:
			return frameTypeStreamBase
		case debugFrameMaxData:
			return frameTypeMaxData
		case debugFrameMaxStreamData:
			return frameTypeMaxStreamData
		case debugFrameMaxStreams:
			if f.streamType == bidiStream {
				return frameTypeMaxStreamsBidi
			} else {
				return frameTypeMaxStreamsUni
			}
		case debugFrameDataBlocked:
			return frameTypeDataBlocked
		case debugFrameStreamDataBlocked:
			return frameTypeStreamDataBlocked
		case debugFrameStreamsBlocked:
			if f.streamType == bidiStream {
				return frameTypeStreamsBlockedBidi
			} else {
				return frameTypeStreamsBlockedUni
			}
		case debugFrameNewConnectionID:
			return frameTypeNewConnectionID
		case debugFrameRetireConnectionID:
			return frameTypeRetireConnectionID
		case debugFramePathChallenge:
			return frameTypePathChallenge
		case debugFramePathResponse:
			return frameTypePathResponse
		case debugFrameConnectionCloseTransport:
			return frameTypeConnectionCloseTransport
		case debugFrameConnectionCloseApplication:
			return frameTypeConnectionCloseApplication
		case debugFrameHandshakeDone:
			return frameTypeHandshakeDone
		}
		panic(fmt.Errorf("unhandled frame type %T", f))
	}
	// Rebuild each packet's frame list without the ignored frames. A packet
	// whose frames are all ignored is left with a nil frame list.
	for _, p := range d.packets {
		var frames []debugFrame
		for _, f := range p.frames {
			if !tc.ignoreFrames[typeForFrame(f)] {
				frames = append(frames, f)
			}
		}
		p.frames = frames
	}
	tc.lastDatagram = d
	return d
}
   585  
   586  // readPacket reads the next packet sent by the Conn.
   587  // It returns nil if the Conn has no more packets to send at this time.
   588  func (tc *testConn) readPacket() *testPacket {
   589  	tc.t.Helper()
   590  	for len(tc.sentPackets) == 0 {
   591  		d := tc.readDatagram()
   592  		if d == nil {
   593  			return nil
   594  		}
   595  		for _, p := range d.packets {
   596  			if len(p.frames) == 0 {
   597  				tc.lastPacket = p
   598  				continue
   599  			}
   600  			tc.sentPackets = append(tc.sentPackets, p)
   601  		}
   602  	}
   603  	p := tc.sentPackets[0]
   604  	tc.sentPackets = tc.sentPackets[1:]
   605  	tc.lastPacket = p
   606  	return p
   607  }
   608  
   609  // readFrame reads the next frame sent by the Conn.
   610  // It returns nil if the Conn has no more frames to send at this time.
   611  func (tc *testConn) readFrame() (debugFrame, packetType) {
   612  	tc.t.Helper()
   613  	for len(tc.sentFrames) == 0 {
   614  		p := tc.readPacket()
   615  		if p == nil {
   616  			return nil, packetTypeInvalid
   617  		}
   618  		tc.sentFrames = p.frames
   619  	}
   620  	f := tc.sentFrames[0]
   621  	tc.sentFrames = tc.sentFrames[1:]
   622  	return f, tc.lastPacket.ptype
   623  }
   624  
   625  // wantDatagram indicates that we expect the Conn to send a datagram.
   626  func (tc *testConn) wantDatagram(expectation string, want *testDatagram) {
   627  	tc.t.Helper()
   628  	got := tc.readDatagram()
   629  	if !datagramEqual(got, want) {
   630  		tc.t.Fatalf("%v:\ngot datagram:  %v\nwant datagram: %v", expectation, got, want)
   631  	}
   632  }
   633  
   634  func datagramEqual(a, b *testDatagram) bool {
   635  	if a == nil && b == nil {
   636  		return true
   637  	}
   638  	if a == nil || b == nil {
   639  		return false
   640  	}
   641  	if a.paddedSize != b.paddedSize ||
   642  		a.addr != b.addr ||
   643  		len(a.packets) != len(b.packets) {
   644  		return false
   645  	}
   646  	for i := range a.packets {
   647  		if !packetEqual(a.packets[i], b.packets[i]) {
   648  			return false
   649  		}
   650  	}
   651  	return true
   652  }
   653  
   654  // wantPacket indicates that we expect the Conn to send a packet.
   655  func (tc *testConn) wantPacket(expectation string, want *testPacket) {
   656  	tc.t.Helper()
   657  	got := tc.readPacket()
   658  	if !packetEqual(got, want) {
   659  		tc.t.Fatalf("%v:\ngot packet:  %v\nwant packet: %v", expectation, got, want)
   660  	}
   661  }
   662  
   663  func packetEqual(a, b *testPacket) bool {
   664  	if a == nil && b == nil {
   665  		return true
   666  	}
   667  	if a == nil || b == nil {
   668  		return false
   669  	}
   670  	ac := *a
   671  	ac.frames = nil
   672  	ac.header = 0
   673  	bc := *b
   674  	bc.frames = nil
   675  	bc.header = 0
   676  	if !reflect.DeepEqual(ac, bc) {
   677  		return false
   678  	}
   679  	if len(a.frames) != len(b.frames) {
   680  		return false
   681  	}
   682  	for i := range a.frames {
   683  		if !frameEqual(a.frames[i], b.frames[i]) {
   684  			return false
   685  		}
   686  	}
   687  	return true
   688  }
   689  
   690  // wantFrame indicates that we expect the Conn to send a frame.
   691  func (tc *testConn) wantFrame(expectation string, wantType packetType, want debugFrame) {
   692  	tc.t.Helper()
   693  	got, gotType := tc.readFrame()
   694  	if got == nil {
   695  		tc.t.Fatalf("%v:\nconnection is idle\nwant %v frame: %v", expectation, wantType, want)
   696  	}
   697  	if gotType != wantType {
   698  		tc.t.Fatalf("%v:\ngot %v packet, want %v\ngot frame:  %v", expectation, gotType, wantType, got)
   699  	}
   700  	if !frameEqual(got, want) {
   701  		tc.t.Fatalf("%v:\ngot frame:  %v\nwant frame: %v", expectation, got, want)
   702  	}
   703  }
   704  
   705  func frameEqual(a, b debugFrame) bool {
   706  	switch af := a.(type) {
   707  	case debugFrameConnectionCloseTransport:
   708  		bf, ok := b.(debugFrameConnectionCloseTransport)
   709  		return ok && af.code == bf.code
   710  	}
   711  	return reflect.DeepEqual(a, b)
   712  }
   713  
   714  // wantFrameType indicates that we expect the Conn to send a frame,
   715  // although we don't care about the contents.
   716  func (tc *testConn) wantFrameType(expectation string, wantType packetType, want debugFrame) {
   717  	tc.t.Helper()
   718  	got, gotType := tc.readFrame()
   719  	if got == nil {
   720  		tc.t.Fatalf("%v:\nconnection is idle\nwant %v frame: %v", expectation, wantType, want)
   721  	}
   722  	if gotType != wantType {
   723  		tc.t.Fatalf("%v:\ngot %v packet, want %v\ngot frame:  %v", expectation, gotType, wantType, got)
   724  	}
   725  	if reflect.TypeOf(got) != reflect.TypeOf(want) {
   726  		tc.t.Fatalf("%v:\ngot frame:  %v\nwant frame of type: %v", expectation, got, want)
   727  	}
   728  }
   729  
   730  // wantIdle indicates that we expect the Conn to not send any more frames.
   731  func (tc *testConn) wantIdle(expectation string) {
   732  	tc.t.Helper()
   733  	switch {
   734  	case len(tc.sentFrames) > 0:
   735  		tc.t.Fatalf("expect: %v\nunexpectedly got: %v", expectation, tc.sentFrames[0])
   736  	case len(tc.sentPackets) > 0:
   737  		tc.t.Fatalf("expect: %v\nunexpectedly got: %v", expectation, tc.sentPackets[0])
   738  	}
   739  	if f, _ := tc.readFrame(); f != nil {
   740  		tc.t.Fatalf("expect: %v\nunexpectedly got: %v", expectation, f)
   741  	}
   742  }
   743  
// encodeTestPacket encodes and protects p, returning the resulting datagram
// bytes.
//
// Retry packets are built by encodeRetryPacket using p.originalDstConnID, and
// are returned without padding. Other packets are written into a packetWriter
// reset to 1200 bytes. Their frames are followed by w.appendPaddingTo(pad),
// which presumably pads the datagram to pad bytes. They are then protected:
//   - Initial and Handshake packets use tc's write keys. With a nil tc, only
//     Initial packets can be encoded, using initialKeys(p.dstConnID, serverSide).r.
//   - 1-RTT packets require tc's 1-RTT write keys. The packet key is
//     tc.wkeyAppData.pkt[p.keyNumber], and the key phase bit comes from
//     p.keyPhaseBit.
//
// Missing keys fail the test.
func encodeTestPacket(t *testing.T, tc *testConn, p *testPacket, pad int) []byte {
	t.Helper()
	var w packetWriter
	w.reset(1200)
	var pnumMaxAcked packetNumber // never assigned: always zero
	switch p.ptype {
	case packetTypeRetry:
		return encodeRetryPacket(p.originalDstConnID, retryPacket{
			srcConnID: p.srcConnID,
			dstConnID: p.dstConnID,
			token:     p.token,
		})
	case packetType1RTT:
		w.start1RTTPacket(p.num, pnumMaxAcked, p.dstConnID)
	default:
		w.startProtectedLongHeaderPacket(pnumMaxAcked, longPacket{
			ptype:     p.ptype,
			version:   p.version,
			num:       p.num,
			dstConnID: p.dstConnID,
			srcConnID: p.srcConnID,
			extra:     p.token,
		})
	}
	for _, f := range p.frames {
		f.write(&w)
	}
	w.appendPaddingTo(pad)
	if p.ptype != packetType1RTT {
		var k fixedKeys
		if tc == nil {
			// Without a conn, only Initial packets can be protected, using
			// keys derived from the packet's destination connection ID.
			if p.ptype == packetTypeInitial {
				k = initialKeys(p.dstConnID, serverSide).r
			} else {
				t.Fatalf("sending %v packet with no conn", p.ptype)
			}
		} else {
			switch p.ptype {
			case packetTypeInitial:
				k = tc.keysInitial.w
			case packetTypeHandshake:
				k = tc.keysHandshake.w
			}
		}
		if !k.isSet() {
			t.Fatalf("sending %v packet with no write key", p.ptype)
		}
		w.finishProtectedLongHeaderPacket(pnumMaxAcked, k, longPacket{
			ptype:     p.ptype,
			version:   p.version,
			num:       p.num,
			dstConnID: p.dstConnID,
			srcConnID: p.srcConnID,
			extra:     p.token,
		})
	} else {
		if tc == nil || !tc.wkeyAppData.hdr.isSet() {
			t.Fatalf("sending 1-RTT packet with no write key")
		}
		// Somewhat hackish: Generate a temporary updatingKeyPair that will
		// always use our desired key phase.
		// Both phase slots hold the same key, so whichever slot is selected,
		// the packet is sealed with tc.wkeyAppData.pkt[p.keyNumber].
		k := &updatingKeyPair{
			w: updatingKeys{
				hdr: tc.wkeyAppData.hdr,
				pkt: [2]packetKey{
					tc.wkeyAppData.pkt[p.keyNumber],
					tc.wkeyAppData.pkt[p.keyNumber],
				},
			},
			updateAfter: maxPacketNumber,
		}
		if p.keyPhaseBit {
			k.phase |= keyPhaseBit
		}
		w.finish1RTTPacket(p.num, pnumMaxAcked, p.dstConnID, k)
	}
	return w.datagram()
}
   822  
   823  func parseTestDatagram(t *testing.T, te *testEndpoint, tc *testConn, buf []byte) *testDatagram {
   824  	t.Helper()
   825  	bufSize := len(buf)
   826  	d := &testDatagram{}
   827  	size := len(buf)
   828  	for len(buf) > 0 {
   829  		if buf[0] == 0 {
   830  			d.paddedSize = bufSize
   831  			break
   832  		}
   833  		ptype := getPacketType(buf)
   834  		switch ptype {
   835  		case packetTypeRetry:
   836  			retry, ok := parseRetryPacket(buf, te.lastInitialDstConnID)
   837  			if !ok {
   838  				t.Fatalf("could not parse %v packet", ptype)
   839  			}
   840  			return &testDatagram{
   841  				packets: []*testPacket{{
   842  					ptype:     packetTypeRetry,
   843  					dstConnID: retry.dstConnID,
   844  					srcConnID: retry.srcConnID,
   845  					token:     retry.token,
   846  				}},
   847  			}
   848  		case packetTypeInitial, packetTypeHandshake:
   849  			var k fixedKeys
   850  			if tc == nil {
   851  				if ptype == packetTypeInitial {
   852  					p, _ := parseGenericLongHeaderPacket(buf)
   853  					k = initialKeys(p.srcConnID, serverSide).w
   854  				} else {
   855  					t.Fatalf("reading %v packet with no conn", ptype)
   856  				}
   857  			} else {
   858  				switch ptype {
   859  				case packetTypeInitial:
   860  					k = tc.keysInitial.r
   861  				case packetTypeHandshake:
   862  					k = tc.keysHandshake.r
   863  				}
   864  			}
   865  			if !k.isSet() {
   866  				t.Fatalf("reading %v packet with no read key", ptype)
   867  			}
   868  			var pnumMax packetNumber // TODO: Track packet numbers.
   869  			p, n := parseLongHeaderPacket(buf, k, pnumMax)
   870  			if n < 0 {
   871  				t.Fatalf("packet parse error")
   872  			}
   873  			frames, err := parseTestFrames(t, p.payload)
   874  			if err != nil {
   875  				t.Fatal(err)
   876  			}
   877  			var token []byte
   878  			if ptype == packetTypeInitial && len(p.extra) > 0 {
   879  				token = p.extra
   880  			}
   881  			d.packets = append(d.packets, &testPacket{
   882  				ptype:     p.ptype,
   883  				header:    buf[0],
   884  				version:   p.version,
   885  				num:       p.num,
   886  				dstConnID: p.dstConnID,
   887  				srcConnID: p.srcConnID,
   888  				token:     token,
   889  				frames:    frames,
   890  			})
   891  			buf = buf[n:]
   892  		case packetType1RTT:
   893  			if tc == nil || !tc.rkeyAppData.hdr.isSet() {
   894  				t.Fatalf("reading 1-RTT packet with no read key")
   895  			}
   896  			var pnumMax packetNumber // TODO: Track packet numbers.
   897  			pnumOff := 1 + len(tc.peerConnID)
   898  			// Try unprotecting the packet with the first maxTestKeyPhases keys.
   899  			var phase int
   900  			var pnum packetNumber
   901  			var hdr []byte
   902  			var pay []byte
   903  			var err error
   904  			for phase = 0; phase < maxTestKeyPhases; phase++ {
   905  				b := append([]byte{}, buf...)
   906  				hdr, pay, pnum, err = tc.rkeyAppData.hdr.unprotect(b, pnumOff, pnumMax)
   907  				if err != nil {
   908  					t.Fatalf("1-RTT packet header parse error")
   909  				}
   910  				k := tc.rkeyAppData.pkt[phase]
   911  				pay, err = k.unprotect(hdr, pay, pnum)
   912  				if err == nil {
   913  					break
   914  				}
   915  			}
   916  			if err != nil {
   917  				t.Fatalf("1-RTT packet payload parse error")
   918  			}
   919  			frames, err := parseTestFrames(t, pay)
   920  			if err != nil {
   921  				t.Fatal(err)
   922  			}
   923  			d.packets = append(d.packets, &testPacket{
   924  				ptype:       packetType1RTT,
   925  				header:      hdr[0],
   926  				num:         pnum,
   927  				dstConnID:   hdr[1:][:len(tc.peerConnID)],
   928  				keyPhaseBit: hdr[0]&keyPhaseBit != 0,
   929  				keyNumber:   phase,
   930  				frames:      frames,
   931  			})
   932  			buf = buf[len(buf):]
   933  		default:
   934  			t.Fatalf("unhandled packet type %v", ptype)
   935  		}
   936  	}
   937  	// This is rather hackish: If the last frame in the last packet
   938  	// in the datagram is PADDING, then remove it and record
   939  	// the padded size in the testDatagram.paddedSize.
   940  	//
   941  	// This makes it easier to write a test that expects a datagram
   942  	// padded to 1200 bytes.
   943  	if len(d.packets) > 0 && len(d.packets[len(d.packets)-1].frames) > 0 {
   944  		p := d.packets[len(d.packets)-1]
   945  		f := p.frames[len(p.frames)-1]
   946  		if _, ok := f.(debugFramePadding); ok {
   947  			p.frames = p.frames[:len(p.frames)-1]
   948  			d.paddedSize = size
   949  		}
   950  	}
   951  	return d
   952  }
   953  
   954  func parseTestFrames(t *testing.T, payload []byte) ([]debugFrame, error) {
   955  	t.Helper()
   956  	var frames []debugFrame
   957  	for len(payload) > 0 {
   958  		f, n := parseDebugFrame(payload)
   959  		if n < 0 {
   960  			return nil, errors.New("error parsing frames")
   961  		}
   962  		frames = append(frames, f)
   963  		payload = payload[n:]
   964  	}
   965  	return frames, nil
   966  }
   967  
   968  func spaceForPacketType(ptype packetType) numberSpace {
   969  	switch ptype {
   970  	case packetTypeInitial:
   971  		return initialSpace
   972  	case packetType0RTT:
   973  		panic("TODO: packetType0RTT")
   974  	case packetTypeHandshake:
   975  		return handshakeSpace
   976  	case packetTypeRetry:
   977  		panic("retry packets have no number space")
   978  	case packetType1RTT:
   979  		return appDataSpace
   980  	}
   981  	panic("unknown packet type")
   982  }
   983  
// testConnHooks implements connTestHooks.
//
// It is testConn under a distinct defined type, so the hook methods declared
// on it (init, handleTLSEvent, nextMessage, newConnID, timeNow) are not part
// of testConn's own method set. Convert with (*testConn)(tc) to get back to
// the full testConn.
type testConnHooks testConn
   986  
   987  func (tc *testConnHooks) init() {
   988  	tc.conn.keysAppData.updateAfter = maxPacketNumber // disable key updates
   989  	tc.keysInitial.r = tc.conn.keysInitial.w
   990  	tc.keysInitial.w = tc.conn.keysInitial.r
   991  	if tc.conn.side == serverSide {
   992  		tc.endpoint.acceptQueue = append(tc.endpoint.acceptQueue, (*testConn)(tc))
   993  	}
   994  }
   995  
   996  // handleTLSEvent processes TLS events generated by
   997  // the connection under test's tls.QUICConn.
   998  //
   999  // We maintain a second tls.QUICConn representing the peer,
  1000  // and feed the TLS handshake data into it.
  1001  //
  1002  // We stash TLS handshake data from both sides in the testConn,
  1003  // where it can be used by tests.
  1004  //
  1005  // We snoop packet protection keys out of the tls.QUICConns,
  1006  // and verify that both sides of the connection are getting
  1007  // matching keys.
  1008  func (tc *testConnHooks) handleTLSEvent(e tls.QUICEvent) {
  1009  	checkKey := func(typ string, secrets *[numberSpaceCount]keySecret, e tls.QUICEvent) {
  1010  		var space numberSpace
  1011  		switch {
  1012  		case e.Level == tls.QUICEncryptionLevelHandshake:
  1013  			space = handshakeSpace
  1014  		case e.Level == tls.QUICEncryptionLevelApplication:
  1015  			space = appDataSpace
  1016  		default:
  1017  			tc.t.Errorf("unexpected encryption level %v", e.Level)
  1018  			return
  1019  		}
  1020  		if secrets[space].secret == nil {
  1021  			secrets[space].suite = e.Suite
  1022  			secrets[space].secret = append([]byte{}, e.Data...)
  1023  		} else if secrets[space].suite != e.Suite || !bytes.Equal(secrets[space].secret, e.Data) {
  1024  			tc.t.Errorf("%v key mismatch for level for level %v", typ, e.Level)
  1025  		}
  1026  	}
  1027  	setAppDataKey := func(suite uint16, secret []byte, k *test1RTTKeys) {
  1028  		k.hdr.init(suite, secret)
  1029  		for i := 0; i < len(k.pkt); i++ {
  1030  			k.pkt[i].init(suite, secret)
  1031  			secret = updateSecret(suite, secret)
  1032  		}
  1033  	}
  1034  	switch e.Kind {
  1035  	case tls.QUICSetReadSecret:
  1036  		checkKey("write", &tc.wsecrets, e)
  1037  		switch e.Level {
  1038  		case tls.QUICEncryptionLevelHandshake:
  1039  			tc.keysHandshake.w.init(e.Suite, e.Data)
  1040  		case tls.QUICEncryptionLevelApplication:
  1041  			setAppDataKey(e.Suite, e.Data, &tc.wkeyAppData)
  1042  		}
  1043  	case tls.QUICSetWriteSecret:
  1044  		checkKey("read", &tc.rsecrets, e)
  1045  		switch e.Level {
  1046  		case tls.QUICEncryptionLevelHandshake:
  1047  			tc.keysHandshake.r.init(e.Suite, e.Data)
  1048  		case tls.QUICEncryptionLevelApplication:
  1049  			setAppDataKey(e.Suite, e.Data, &tc.rkeyAppData)
  1050  		}
  1051  	case tls.QUICWriteData:
  1052  		tc.cryptoDataOut[e.Level] = append(tc.cryptoDataOut[e.Level], e.Data...)
  1053  		tc.peerTLSConn.HandleData(e.Level, e.Data)
  1054  	}
  1055  	for {
  1056  		e := tc.peerTLSConn.NextEvent()
  1057  		switch e.Kind {
  1058  		case tls.QUICNoEvent:
  1059  			return
  1060  		case tls.QUICSetReadSecret:
  1061  			checkKey("write", &tc.rsecrets, e)
  1062  			switch e.Level {
  1063  			case tls.QUICEncryptionLevelHandshake:
  1064  				tc.keysHandshake.r.init(e.Suite, e.Data)
  1065  			case tls.QUICEncryptionLevelApplication:
  1066  				setAppDataKey(e.Suite, e.Data, &tc.rkeyAppData)
  1067  			}
  1068  		case tls.QUICSetWriteSecret:
  1069  			checkKey("read", &tc.wsecrets, e)
  1070  			switch e.Level {
  1071  			case tls.QUICEncryptionLevelHandshake:
  1072  				tc.keysHandshake.w.init(e.Suite, e.Data)
  1073  			case tls.QUICEncryptionLevelApplication:
  1074  				setAppDataKey(e.Suite, e.Data, &tc.wkeyAppData)
  1075  			}
  1076  		case tls.QUICWriteData:
  1077  			tc.cryptoDataIn[e.Level] = append(tc.cryptoDataIn[e.Level], e.Data...)
  1078  		case tls.QUICTransportParameters:
  1079  			p, err := unmarshalTransportParams(e.Data)
  1080  			if err != nil {
  1081  				tc.t.Logf("sent unparseable transport parameters %x %v", e.Data, err)
  1082  			} else {
  1083  				tc.sentTransportParameters = &p
  1084  			}
  1085  		}
  1086  	}
  1087  }
  1088  
// nextMessage is called by the Conn's event loop to request its next event.
//
// It records the Conn's current timer in tc.timer and then repeatedly, in order:
//   - returns a timerEvent if timer is set and is not after the endpoint's
//     clock (tc.endpoint.now), unless it is the same instant that fired last
//     time, which is reported as a test error;
//   - returns a message already queued on msgc, without blocking;
//   - calls tc.wakeAsync, and stops looping once it returns false.
//
// When the loop stops, the conn is considered idle. tc.idlec, if set, is
// closed and cleared, and nextMessage blocks until a message arrives on msgc.
// The returned time is always tc.endpoint.now.
func (tc *testConnHooks) nextMessage(msgc chan any, timer time.Time) (now time.Time, m any) {
	tc.timer = timer
	for {
		if !timer.IsZero() && !timer.After(tc.endpoint.now) {
			if timer.Equal(tc.timerLastFired) {
				// If the connection timer fires at time T, the Conn should take some
				// action to advance the timer into the future. If the Conn reschedules
				// the timer for the same time, it isn't making progress and we have a bug.
				tc.t.Errorf("connection timer spinning; now=%v timer=%v", tc.endpoint.now, timer)
				// After reporting, fall through to the message checks below rather
				// than firing the same timer again.
			} else {
				tc.timerLastFired = timer
				return tc.endpoint.now, timerEvent{}
			}
		}
		// Non-blocking receive: deliver a message only if one is already queued.
		select {
		case m := <-msgc:
			return tc.endpoint.now, m
		default:
		}
		// NOTE(review): wakeAsync presumably lets a pending async test operation
		// make progress, and false means nothing was woken. Confirm against its
		// definition. This break exits the for loop.
		if !tc.wakeAsync() {
			break
		}
	}
	// If the message queue is empty, then the conn is idle.
	if tc.idlec != nil {
		idlec := tc.idlec
		tc.idlec = nil
		close(idlec)
	}
	// Block until the next message; the timer is not rechecked here.
	m = <-msgc
	return tc.endpoint.now, m
}
  1122  
  1123  func (tc *testConnHooks) newConnID(seq int64) ([]byte, error) {
  1124  	return testLocalConnID(seq), nil
  1125  }
  1126  
  1127  func (tc *testConnHooks) timeNow() time.Time {
  1128  	return tc.endpoint.now
  1129  }
  1130  
  1131  // testLocalConnID returns the connection ID with a given sequence number
  1132  // used by a Conn under test.
  1133  func testLocalConnID(seq int64) []byte {
  1134  	cid := make([]byte, connIDLen)
  1135  	copy(cid, []byte{0xc0, 0xff, 0xee})
  1136  	cid[len(cid)-1] = byte(seq)
  1137  	return cid
  1138  }
  1139  
  1140  // testPeerConnID returns the connection ID with a given sequence number
  1141  // used by the fake peer of a Conn under test.
  1142  func testPeerConnID(seq int64) []byte {
  1143  	// Use a different length than we choose for our own conn ids,
  1144  	// to help catch any bad assumptions.
  1145  	return []byte{0xbe, 0xee, 0xff, byte(seq)}
  1146  }
  1147  
  1148  func testPeerStatelessResetToken(seq int64) statelessResetToken {
  1149  	return statelessResetToken{
  1150  		0xee, 0xee, 0xee, 0xee, 0xee, 0xee, 0xee, 0xee,
  1151  		0xee, 0xee, 0xee, 0xee, 0xee, 0xee, 0xee, byte(seq),
  1152  	}
  1153  }
  1154  
  1155  // canceledContext returns a canceled Context.
  1156  //
  1157  // Functions which take a context preference progress over cancelation.
  1158  // For example, a read with a canceled context will return data if any is available.
  1159  // Tests use canceled contexts to perform non-blocking operations.
  1160  func canceledContext() context.Context {
  1161  	ctx, cancel := context.WithCancel(context.Background())
  1162  	cancel()
  1163  	return ctx
  1164  }