golang.org/x/net@v0.25.1-0.20240516223405-c87a5b62e243/quic/conn_test.go (about) 1 // Copyright 2023 The Go Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 //go:build go1.21 6 7 package quic 8 9 import ( 10 "bytes" 11 "context" 12 "crypto/tls" 13 "errors" 14 "flag" 15 "fmt" 16 "log/slog" 17 "math" 18 "net/netip" 19 "reflect" 20 "strings" 21 "testing" 22 "time" 23 24 "golang.org/x/net/quic/qlog" 25 ) 26 27 var ( 28 testVV = flag.Bool("vv", false, "even more verbose test output") 29 qlogdir = flag.String("qlog", "", "write qlog logs to directory") 30 ) 31 32 func TestConnTestConn(t *testing.T) { 33 tc := newTestConn(t, serverSide) 34 tc.handshake() 35 if got, want := tc.timeUntilEvent(), defaultMaxIdleTimeout; got != want { 36 t.Errorf("new conn timeout=%v, want %v (max_idle_timeout)", got, want) 37 } 38 39 ranAt, _ := runAsync(tc, func(ctx context.Context) (when time.Time, _ error) { 40 tc.conn.runOnLoop(ctx, func(now time.Time, c *Conn) { 41 when = now 42 }) 43 return 44 }).result() 45 if !ranAt.Equal(tc.endpoint.now) { 46 t.Errorf("func ran on loop at %v, want %v", ranAt, tc.endpoint.now) 47 } 48 tc.wait() 49 50 nextTime := tc.endpoint.now.Add(defaultMaxIdleTimeout / 2) 51 tc.advanceTo(nextTime) 52 ranAt, _ = runAsync(tc, func(ctx context.Context) (when time.Time, _ error) { 53 tc.conn.runOnLoop(ctx, func(now time.Time, c *Conn) { 54 when = now 55 }) 56 return 57 }).result() 58 if !ranAt.Equal(nextTime) { 59 t.Errorf("func ran on loop at %v, want %v", ranAt, nextTime) 60 } 61 tc.wait() 62 63 tc.advanceToTimer() 64 if got := tc.conn.lifetime.state; got != connStateDone { 65 t.Errorf("after advancing to idle timeout, conn state = %v, want done", got) 66 } 67 } 68 69 type testDatagram struct { 70 packets []*testPacket 71 paddedSize int 72 addr netip.AddrPort 73 } 74 75 func (d testDatagram) String() string { 76 var b strings.Builder 77 fmt.Fprintf(&b, "datagram with %v 
packets", len(d.packets)) 78 if d.paddedSize > 0 { 79 fmt.Fprintf(&b, " (padded to %v bytes)", d.paddedSize) 80 } 81 b.WriteString(":") 82 for _, p := range d.packets { 83 b.WriteString("\n") 84 b.WriteString(p.String()) 85 } 86 return b.String() 87 } 88 89 type testPacket struct { 90 ptype packetType 91 header byte 92 version uint32 93 num packetNumber 94 keyPhaseBit bool 95 keyNumber int 96 dstConnID []byte 97 srcConnID []byte 98 token []byte 99 originalDstConnID []byte // used for encoding Retry packets 100 frames []debugFrame 101 } 102 103 func (p testPacket) String() string { 104 var b strings.Builder 105 fmt.Fprintf(&b, " %v %v", p.ptype, p.num) 106 if p.version != 0 { 107 fmt.Fprintf(&b, " version=%v", p.version) 108 } 109 if p.srcConnID != nil { 110 fmt.Fprintf(&b, " src={%x}", p.srcConnID) 111 } 112 if p.dstConnID != nil { 113 fmt.Fprintf(&b, " dst={%x}", p.dstConnID) 114 } 115 if p.token != nil { 116 fmt.Fprintf(&b, " token={%x}", p.token) 117 } 118 for _, f := range p.frames { 119 fmt.Fprintf(&b, "\n %v", f) 120 } 121 return b.String() 122 } 123 124 // maxTestKeyPhases is the maximum number of 1-RTT keys we'll generate in a test. 125 const maxTestKeyPhases = 3 126 127 // A testConn is a Conn whose external interactions (sending and receiving packets, 128 // setting timers) can be manipulated in tests. 129 type testConn struct { 130 t *testing.T 131 conn *Conn 132 endpoint *testEndpoint 133 timer time.Time 134 timerLastFired time.Time 135 idlec chan struct{} // only accessed on the conn's loop 136 137 // Keys are distinct from the conn's keys, 138 // because the test may know about keys before the conn does. 139 // For example, when sending a datagram with coalesced 140 // Initial and Handshake packets to a client conn, 141 // we use Handshake keys to encrypt the packet. 142 // The client only acquires those keys when it processes 143 // the Initial packet. 
144 keysInitial fixedKeyPair 145 keysHandshake fixedKeyPair 146 rkeyAppData test1RTTKeys 147 wkeyAppData test1RTTKeys 148 rsecrets [numberSpaceCount]keySecret 149 wsecrets [numberSpaceCount]keySecret 150 151 // testConn uses a test hook to snoop on the conn's TLS events. 152 // CRYPTO data produced by the conn's QUICConn is placed in 153 // cryptoDataOut. 154 // 155 // The peerTLSConn is is a QUICConn representing the peer. 156 // CRYPTO data produced by the conn is written to peerTLSConn, 157 // and data produced by peerTLSConn is placed in cryptoDataIn. 158 cryptoDataOut map[tls.QUICEncryptionLevel][]byte 159 cryptoDataIn map[tls.QUICEncryptionLevel][]byte 160 peerTLSConn *tls.QUICConn 161 162 // Information about the conn's (fake) peer. 163 peerConnID []byte // source conn id of peer's packets 164 peerNextPacketNum [numberSpaceCount]packetNumber // next packet number to use 165 166 // Datagrams, packets, and frames sent by the conn, 167 // but not yet processed by the test. 168 sentDatagrams [][]byte 169 sentPackets []*testPacket 170 sentFrames []debugFrame 171 lastDatagram *testDatagram 172 lastPacket *testPacket 173 174 recvDatagram chan *datagram 175 176 // Transport parameters sent by the conn. 177 sentTransportParameters *transportParameters 178 179 // Frame types to ignore in tests. 180 ignoreFrames map[byte]bool 181 182 // Values to set in packets sent to the conn. 183 sendKeyNumber int 184 sendKeyPhaseBit bool 185 186 asyncTestState 187 } 188 189 type test1RTTKeys struct { 190 hdr headerKey 191 pkt [maxTestKeyPhases]packetKey 192 } 193 194 type keySecret struct { 195 suite uint16 196 secret []byte 197 } 198 199 // newTestConn creates a Conn for testing. 200 // 201 // The Conn's event loop is controlled by the test, 202 // allowing test code to access Conn state directly 203 // by first ensuring the loop goroutine is idle. 
func newTestConn(t *testing.T, side connSide, opts ...any) *testConn {
	t.Helper()
	config := &Config{
		TLSConfig:         newTestTLSConfig(side),
		StatelessResetKey: testStatelessResetKey,
		QLogLogger: slog.New(qlog.NewJSONHandler(qlog.HandlerOptions{
			Level: QLogLevelFrame,
			Dir:   *qlogdir,
		})),
	}
	var cids newServerConnIDs
	if side == serverSide {
		// The initial connection ID for the server is chosen by the client.
		cids.srcConnID = testPeerConnID(0)
		cids.dstConnID = testPeerConnID(-1)
		cids.originalDstConnID = cids.dstConnID
	}
	// Each option is a function customizing one part of the setup:
	// the Config, its TLS config, the server's connection IDs,
	// the peer's transport parameters, or the testConn itself.
	// The last two are applied later, when the testConn is created.
	var configTransportParams []func(*transportParameters)
	var configTestConn []func(*testConn)
	for _, o := range opts {
		switch o := o.(type) {
		case func(*Config):
			o(config)
		case func(*tls.Config):
			o(config.TLSConfig)
		case func(cids *newServerConnIDs):
			o(&cids)
		case func(p *transportParameters):
			configTransportParams = append(configTransportParams, o)
		case func(p *testConn):
			configTestConn = append(configTestConn, o)
		default:
			t.Fatalf("unknown newTestConn option %T", o)
		}
	}

	endpoint := newTestEndpoint(t, config)
	endpoint.configTransportParams = configTransportParams
	endpoint.configTestConn = configTestConn
	conn, err := endpoint.e.newConn(
		endpoint.now,
		config,
		side,
		cids,
		"",
		netip.MustParseAddrPort("127.0.0.1:443"))
	if err != nil {
		t.Fatal(err)
	}
	tc := endpoint.conns[conn]
	tc.wait()
	return tc
}

// newTestConnForConn wraps conn in a testConn attached to endpoint,
// installs the test hooks on conn, and sets up the peer's tls.QUICConn.
// If the endpoint already holds a peer QUICConn, that one is adopted;
// otherwise a new one is created and started with the peer's transport
// parameters (after applying the endpoint's configTransportParams).
func newTestConnForConn(t *testing.T, endpoint *testEndpoint, conn *Conn) *testConn {
	t.Helper()
	tc := &testConn{
		t:          t,
		endpoint:   endpoint,
		conn:       conn,
		peerConnID: testPeerConnID(0),
		ignoreFrames: map[byte]bool{
			frameTypePadding: true, // ignore PADDING by default
		},
		cryptoDataOut: make(map[tls.QUICEncryptionLevel][]byte),
		cryptoDataIn:  make(map[tls.QUICEncryptionLevel][]byte),
		recvDatagram:  make(chan *datagram),
	}
	t.Cleanup(tc.cleanup)
	for _, f := range endpoint.configTestConn {
		f(tc)
	}
	conn.testHooks = (*testConnHooks)(tc)

	if endpoint.peerTLSConn != nil {
		tc.peerTLSConn = endpoint.peerTLSConn
		endpoint.peerTLSConn = nil
		return tc
	}

	peerProvidedParams := defaultTransportParameters()
	peerProvidedParams.initialSrcConnID = testPeerConnID(0)
	if conn.side == clientSide {
		peerProvidedParams.originalDstConnID = testLocalConnID(-1)
	}
	for _, f := range endpoint.configTransportParams {
		f(&peerProvidedParams)
	}

	peerQUICConfig := &tls.QUICConfig{TLSConfig: newTestTLSConfig(conn.side.peer())}
	if conn.side == clientSide {
		tc.peerTLSConn = tls.QUICServer(peerQUICConfig)
	} else {
		tc.peerTLSConn = tls.QUICClient(peerQUICConfig)
	}
	tc.peerTLSConn.SetTransportParameters(marshalTransportParameters(peerProvidedParams))
	// Report (rather than silently drop) a failure to start the peer's
	// handshake. Errorf is used since this may not run on the test goroutine.
	if err := tc.peerTLSConn.Start(context.Background()); err != nil {
		t.Errorf("starting peer QUICConn: %v", err)
	}
	t.Cleanup(func() {
		tc.peerTLSConn.Close()
	})

	return tc
}

// advance causes time to pass.
func (tc *testConn) advance(d time.Duration) {
	tc.t.Helper()
	tc.endpoint.advance(d)
}

// advanceTo sets the current time.
func (tc *testConn) advanceTo(now time.Time) {
	tc.t.Helper()
	tc.endpoint.advanceTo(now)
}

// advanceToTimer sets the current time to the time of the Conn's next timer event.
// It fails the test if no timer is set.
func (tc *testConn) advanceToTimer() {
	tc.t.Helper()
	if tc.timer.IsZero() {
		tc.t.Fatalf("advancing to timer, but timer is not set")
	}
	tc.advanceTo(tc.timer)
}

// timerDelay returns the time remaining until the Conn's timer fires:
// infiniteDuration if no timer is set, and 0 if it is already due.
// It is equivalent to timeUntilEvent.
func (tc *testConn) timerDelay() time.Duration {
	return tc.timeUntilEvent()
}

// infiniteDuration is the delay reported when the Conn has no timer set.
const infiniteDuration = time.Duration(math.MaxInt64)

// timeUntilEvent returns the amount of time until the next connection event.
341 func (tc *testConn) timeUntilEvent() time.Duration { 342 if tc.timer.IsZero() { 343 return infiniteDuration 344 } 345 if tc.timer.Before(tc.endpoint.now) { 346 return 0 347 } 348 return tc.timer.Sub(tc.endpoint.now) 349 } 350 351 // wait blocks until the conn becomes idle. 352 // The conn is idle when it is blocked waiting for a packet to arrive or a timer to expire. 353 // Tests shouldn't need to call wait directly. 354 // testConn methods that wake the Conn event loop will call wait for them. 355 func (tc *testConn) wait() { 356 tc.t.Helper() 357 idlec := make(chan struct{}) 358 fail := false 359 tc.conn.sendMsg(func(now time.Time, c *Conn) { 360 if tc.idlec != nil { 361 tc.t.Errorf("testConn.wait called concurrently") 362 fail = true 363 close(idlec) 364 } else { 365 // nextMessage will close idlec. 366 tc.idlec = idlec 367 } 368 }) 369 select { 370 case <-idlec: 371 case <-tc.conn.donec: 372 // We may have async ops that can proceed now that the conn is done. 373 tc.wakeAsync() 374 } 375 if fail { 376 panic(fail) 377 } 378 } 379 380 func (tc *testConn) cleanup() { 381 if tc.conn == nil { 382 return 383 } 384 tc.conn.exit() 385 <-tc.conn.donec 386 } 387 388 func (tc *testConn) acceptStream() *Stream { 389 tc.t.Helper() 390 s, err := tc.conn.AcceptStream(canceledContext()) 391 if err != nil { 392 tc.t.Fatalf("conn.AcceptStream() = %v, want stream", err) 393 } 394 s.SetReadContext(canceledContext()) 395 s.SetWriteContext(canceledContext()) 396 return s 397 } 398 399 func logDatagram(t *testing.T, text string, d *testDatagram) { 400 t.Helper() 401 if !*testVV { 402 return 403 } 404 pad := "" 405 if d.paddedSize > 0 { 406 pad = fmt.Sprintf(" (padded to %v)", d.paddedSize) 407 } 408 t.Logf("%v datagram%v", text, pad) 409 for _, p := range d.packets { 410 var s string 411 switch p.ptype { 412 case packetType1RTT: 413 s = fmt.Sprintf(" %v pnum=%v", p.ptype, p.num) 414 default: 415 s = fmt.Sprintf(" %v pnum=%v ver=%v dst={%x} src={%x}", p.ptype, p.num, p.version, 
p.dstConnID, p.srcConnID) 416 } 417 if p.token != nil { 418 s += fmt.Sprintf(" token={%x}", p.token) 419 } 420 if p.keyPhaseBit { 421 s += fmt.Sprintf(" KeyPhase") 422 } 423 if p.keyNumber != 0 { 424 s += fmt.Sprintf(" keynum=%v", p.keyNumber) 425 } 426 t.Log(s) 427 for _, f := range p.frames { 428 t.Logf(" %v", f) 429 } 430 } 431 } 432 433 // write sends the Conn a datagram. 434 func (tc *testConn) write(d *testDatagram) { 435 tc.t.Helper() 436 tc.endpoint.writeDatagram(d) 437 } 438 439 // writeFrame sends the Conn a datagram containing the given frames. 440 func (tc *testConn) writeFrames(ptype packetType, frames ...debugFrame) { 441 tc.t.Helper() 442 space := spaceForPacketType(ptype) 443 dstConnID := tc.conn.connIDState.local[0].cid 444 if tc.conn.connIDState.local[0].seq == -1 && ptype != packetTypeInitial { 445 // Only use the transient connection ID in Initial packets. 446 dstConnID = tc.conn.connIDState.local[1].cid 447 } 448 d := &testDatagram{ 449 packets: []*testPacket{{ 450 ptype: ptype, 451 num: tc.peerNextPacketNum[space], 452 keyNumber: tc.sendKeyNumber, 453 keyPhaseBit: tc.sendKeyPhaseBit, 454 frames: frames, 455 version: quicVersion1, 456 dstConnID: dstConnID, 457 srcConnID: tc.peerConnID, 458 }}, 459 addr: tc.conn.peerAddr, 460 } 461 if ptype == packetTypeInitial && tc.conn.side == serverSide { 462 d.paddedSize = 1200 463 } 464 tc.write(d) 465 } 466 467 // writeAckForAll sends the Conn a datagram containing an ack for all packets up to the 468 // last one received. 469 func (tc *testConn) writeAckForAll() { 470 tc.t.Helper() 471 if tc.lastPacket == nil { 472 return 473 } 474 tc.writeFrames(tc.lastPacket.ptype, debugFrameAck{ 475 ranges: []i64range[packetNumber]{{0, tc.lastPacket.num + 1}}, 476 }) 477 } 478 479 // writeAckForLatest sends the Conn a datagram containing an ack for the 480 // most recent packet received. 
func (tc *testConn) writeAckForLatest() {
	tc.t.Helper()
	if tc.lastPacket == nil {
		return
	}
	tc.writeFrames(tc.lastPacket.ptype, debugFrameAck{
		ranges: []i64range[packetNumber]{{tc.lastPacket.num, tc.lastPacket.num + 1}},
	})
}

// ignoreFrame hides frames of the given type sent by the Conn.
func (tc *testConn) ignoreFrame(frameType byte) {
	tc.ignoreFrames[frameType] = true
}

// readDatagram reads the next datagram sent by the Conn.
// It returns nil if the Conn has no more datagrams to send at this time.
//
// Frames whose type is marked in tc.ignoreFrames are removed from the
// returned packets. Any unconsumed packets and frames from earlier reads
// are discarded.
func (tc *testConn) readDatagram() *testDatagram {
	tc.t.Helper()
	tc.wait()
	tc.sentPackets = nil
	tc.sentFrames = nil
	buf := tc.endpoint.read()
	if buf == nil {
		return nil
	}
	d := parseTestDatagram(tc.t, tc.endpoint, tc, buf)
	// Log the datagram before removing ignored frames.
	// When things go wrong, it's useful to see all the frames.
	logDatagram(tc.t, "-> conn under test sends", d)
	typeForFrame := func(f debugFrame) byte {
		// This is very clunky, and points at a problem
		// in how we specify what frames to ignore in tests.
		//
		// We mark frames to ignore using the frame type,
		// but we've got a debugFrame data structure here.
		// Perhaps we should be ignoring frames by debugFrame
		// type instead: tc.ignoreFrame[debugFrameAck]().
		switch f := f.(type) {
		case debugFramePadding:
			return frameTypePadding
		case debugFramePing:
			return frameTypePing
		case debugFrameAck:
			return frameTypeAck
		case debugFrameResetStream:
			return frameTypeResetStream
		case debugFrameStopSending:
			return frameTypeStopSending
		case debugFrameCrypto:
			return frameTypeCrypto
		case debugFrameNewToken:
			return frameTypeNewToken
		case debugFrameStream:
			return frameTypeStreamBase
		case debugFrameMaxData:
			return frameTypeMaxData
		case debugFrameMaxStreamData:
			return frameTypeMaxStreamData
		case debugFrameMaxStreams:
			if f.streamType == bidiStream {
				return frameTypeMaxStreamsBidi
			}
			return frameTypeMaxStreamsUni
		case debugFrameDataBlocked:
			return frameTypeDataBlocked
		case debugFrameStreamDataBlocked:
			return frameTypeStreamDataBlocked
		case debugFrameStreamsBlocked:
			if f.streamType == bidiStream {
				return frameTypeStreamsBlockedBidi
			}
			return frameTypeStreamsBlockedUni
		case debugFrameNewConnectionID:
			return frameTypeNewConnectionID
		case debugFrameRetireConnectionID:
			return frameTypeRetireConnectionID
		case debugFramePathChallenge:
			return frameTypePathChallenge
		case debugFramePathResponse:
			return frameTypePathResponse
		case debugFrameConnectionCloseTransport:
			return frameTypeConnectionCloseTransport
		case debugFrameConnectionCloseApplication:
			return frameTypeConnectionCloseApplication
		case debugFrameHandshakeDone:
			return frameTypeHandshakeDone
		}
		panic(fmt.Errorf("unhandled frame type %T", f))
	}
	for _, p := range d.packets {
		var frames []debugFrame
		for _, f := range p.frames {
			if !tc.ignoreFrames[typeForFrame(f)] {
				frames = append(frames, f)
			}
		}
		p.frames = frames
	}
	tc.lastDatagram = d
	return d
}

// readPacket reads the next packet sent by the Conn.
587 // It returns nil if the Conn has no more packets to send at this time. 588 func (tc *testConn) readPacket() *testPacket { 589 tc.t.Helper() 590 for len(tc.sentPackets) == 0 { 591 d := tc.readDatagram() 592 if d == nil { 593 return nil 594 } 595 for _, p := range d.packets { 596 if len(p.frames) == 0 { 597 tc.lastPacket = p 598 continue 599 } 600 tc.sentPackets = append(tc.sentPackets, p) 601 } 602 } 603 p := tc.sentPackets[0] 604 tc.sentPackets = tc.sentPackets[1:] 605 tc.lastPacket = p 606 return p 607 } 608 609 // readFrame reads the next frame sent by the Conn. 610 // It returns nil if the Conn has no more frames to send at this time. 611 func (tc *testConn) readFrame() (debugFrame, packetType) { 612 tc.t.Helper() 613 for len(tc.sentFrames) == 0 { 614 p := tc.readPacket() 615 if p == nil { 616 return nil, packetTypeInvalid 617 } 618 tc.sentFrames = p.frames 619 } 620 f := tc.sentFrames[0] 621 tc.sentFrames = tc.sentFrames[1:] 622 return f, tc.lastPacket.ptype 623 } 624 625 // wantDatagram indicates that we expect the Conn to send a datagram. 626 func (tc *testConn) wantDatagram(expectation string, want *testDatagram) { 627 tc.t.Helper() 628 got := tc.readDatagram() 629 if !datagramEqual(got, want) { 630 tc.t.Fatalf("%v:\ngot datagram: %v\nwant datagram: %v", expectation, got, want) 631 } 632 } 633 634 func datagramEqual(a, b *testDatagram) bool { 635 if a == nil && b == nil { 636 return true 637 } 638 if a == nil || b == nil { 639 return false 640 } 641 if a.paddedSize != b.paddedSize || 642 a.addr != b.addr || 643 len(a.packets) != len(b.packets) { 644 return false 645 } 646 for i := range a.packets { 647 if !packetEqual(a.packets[i], b.packets[i]) { 648 return false 649 } 650 } 651 return true 652 } 653 654 // wantPacket indicates that we expect the Conn to send a packet. 
655 func (tc *testConn) wantPacket(expectation string, want *testPacket) { 656 tc.t.Helper() 657 got := tc.readPacket() 658 if !packetEqual(got, want) { 659 tc.t.Fatalf("%v:\ngot packet: %v\nwant packet: %v", expectation, got, want) 660 } 661 } 662 663 func packetEqual(a, b *testPacket) bool { 664 if a == nil && b == nil { 665 return true 666 } 667 if a == nil || b == nil { 668 return false 669 } 670 ac := *a 671 ac.frames = nil 672 ac.header = 0 673 bc := *b 674 bc.frames = nil 675 bc.header = 0 676 if !reflect.DeepEqual(ac, bc) { 677 return false 678 } 679 if len(a.frames) != len(b.frames) { 680 return false 681 } 682 for i := range a.frames { 683 if !frameEqual(a.frames[i], b.frames[i]) { 684 return false 685 } 686 } 687 return true 688 } 689 690 // wantFrame indicates that we expect the Conn to send a frame. 691 func (tc *testConn) wantFrame(expectation string, wantType packetType, want debugFrame) { 692 tc.t.Helper() 693 got, gotType := tc.readFrame() 694 if got == nil { 695 tc.t.Fatalf("%v:\nconnection is idle\nwant %v frame: %v", expectation, wantType, want) 696 } 697 if gotType != wantType { 698 tc.t.Fatalf("%v:\ngot %v packet, want %v\ngot frame: %v", expectation, gotType, wantType, got) 699 } 700 if !frameEqual(got, want) { 701 tc.t.Fatalf("%v:\ngot frame: %v\nwant frame: %v", expectation, got, want) 702 } 703 } 704 705 func frameEqual(a, b debugFrame) bool { 706 switch af := a.(type) { 707 case debugFrameConnectionCloseTransport: 708 bf, ok := b.(debugFrameConnectionCloseTransport) 709 return ok && af.code == bf.code 710 } 711 return reflect.DeepEqual(a, b) 712 } 713 714 // wantFrameType indicates that we expect the Conn to send a frame, 715 // although we don't care about the contents. 
func (tc *testConn) wantFrameType(expectation string, wantType packetType, want debugFrame) {
	tc.t.Helper()
	got, gotType := tc.readFrame()
	if got == nil {
		tc.t.Fatalf("%v:\nconnection is idle\nwant %v frame: %v", expectation, wantType, want)
	}
	if gotType != wantType {
		tc.t.Fatalf("%v:\ngot %v packet, want %v\ngot frame: %v", expectation, gotType, wantType, got)
	}
	if reflect.TypeOf(got) != reflect.TypeOf(want) {
		tc.t.Fatalf("%v:\ngot frame: %v\nwant frame of type: %T", expectation, got, want)
	}
}

// wantIdle indicates that we expect the Conn to not send any more frames.
func (tc *testConn) wantIdle(expectation string) {
	tc.t.Helper()
	switch {
	case len(tc.sentFrames) > 0:
		tc.t.Fatalf("expect: %v\nunexpectedly got: %v", expectation, tc.sentFrames[0])
	case len(tc.sentPackets) > 0:
		tc.t.Fatalf("expect: %v\nunexpectedly got: %v", expectation, tc.sentPackets[0])
	}
	if f, _ := tc.readFrame(); f != nil {
		tc.t.Fatalf("expect: %v\nunexpectedly got: %v", expectation, f)
	}
}

// encodeTestPacket encodes p as sent by the peer of tc and returns the
// datagram bytes. pad is passed to packetWriter.appendPaddingTo.
//
// Retry packets are encoded with encodeRetryPacket. Long header packets are
// protected with tc's Initial or Handshake write keys; when tc is nil only
// Initial packets can be encoded, using keys derived from p.dstConnID.
// 1-RTT packets are protected with tc's write key for phase p.keyNumber.
func encodeTestPacket(t *testing.T, tc *testConn, p *testPacket, pad int) []byte {
	t.Helper()
	var w packetWriter
	w.reset(1200)
	var pnumMaxAcked packetNumber
	lp := longPacket{
		ptype:     p.ptype,
		version:   p.version,
		num:       p.num,
		dstConnID: p.dstConnID,
		srcConnID: p.srcConnID,
		extra:     p.token,
	}
	switch p.ptype {
	case packetTypeRetry:
		return encodeRetryPacket(p.originalDstConnID, retryPacket{
			srcConnID: p.srcConnID,
			dstConnID: p.dstConnID,
			token:     p.token,
		})
	case packetType1RTT:
		w.start1RTTPacket(p.num, pnumMaxAcked, p.dstConnID)
	default:
		w.startProtectedLongHeaderPacket(pnumMaxAcked, lp)
	}
	for _, f := range p.frames {
		f.write(&w)
	}
	w.appendPaddingTo(pad)
	if p.ptype != packetType1RTT {
		var k fixedKeys
		if tc == nil {
			if p.ptype == packetTypeInitial {
				k = initialKeys(p.dstConnID, serverSide).r
			} else {
				t.Fatalf("sending %v packet with no conn", p.ptype)
			}
		} else {
			switch p.ptype {
			case packetTypeInitial:
				k = tc.keysInitial.w
			case packetTypeHandshake:
				k = tc.keysHandshake.w
			}
		}
		if !k.isSet() {
			t.Fatalf("sending %v packet with no write key", p.ptype)
		}
		w.finishProtectedLongHeaderPacket(pnumMaxAcked, k, lp)
	} else {
		if tc == nil || !tc.wkeyAppData.hdr.isSet() {
			t.Fatalf("sending 1-RTT packet with no write key")
		}
		// Report an out-of-range key phase clearly instead of
		// panicking on the array index below.
		if p.keyNumber < 0 || p.keyNumber >= len(tc.wkeyAppData.pkt) {
			t.Fatalf("sending 1-RTT packet with key number %v, want in [0, %v)", p.keyNumber, len(tc.wkeyAppData.pkt))
		}
		// Somewhat hackish: Generate a temporary updatingKeyPair that will
		// always use our desired key phase.
		k := &updatingKeyPair{
			w: updatingKeys{
				hdr: tc.wkeyAppData.hdr,
				pkt: [2]packetKey{
					tc.wkeyAppData.pkt[p.keyNumber],
					tc.wkeyAppData.pkt[p.keyNumber],
				},
			},
			updateAfter: maxPacketNumber,
		}
		if p.keyPhaseBit {
			k.phase |= keyPhaseBit
		}
		w.finish1RTTPacket(p.num, pnumMaxAcked, p.dstConnID, k)
	}
	return w.datagram()
}

// parseTestDatagram decodes a datagram sent by the Conn under test.
//
// Long header packets are unprotected with tc's Initial or Handshake read
// keys; when tc is nil only Initial packets can be read, using keys derived
// from the packet's source connection ID. A 1-RTT packet is tried against
// each of tc's first maxTestKeyPhases read keys. A Retry packet is returned
// as a datagram by itself. Trailing padding is recorded in paddedSize.
func parseTestDatagram(t *testing.T, te *testEndpoint, tc *testConn, buf []byte) *testDatagram {
	t.Helper()
	size := len(buf) // total datagram size, recorded as paddedSize when padded
	d := &testDatagram{}
	for len(buf) > 0 {
		if buf[0] == 0 {
			d.paddedSize = size
			break
		}
		ptype := getPacketType(buf)
		switch ptype {
		case packetTypeRetry:
			retry, ok := parseRetryPacket(buf, te.lastInitialDstConnID)
			if !ok {
				t.Fatalf("could not parse %v packet", ptype)
			}
			return &testDatagram{
				packets: []*testPacket{{
					ptype:     packetTypeRetry,
					dstConnID: retry.dstConnID,
					srcConnID: retry.srcConnID,
					token:     retry.token,
				}},
			}
		case packetTypeInitial, packetTypeHandshake:
			var k fixedKeys
			if tc == nil {
				if ptype == packetTypeInitial {
					p, _ := parseGenericLongHeaderPacket(buf)
					k = initialKeys(p.srcConnID, serverSide).w
				} else {
					t.Fatalf("reading %v packet with no conn", ptype)
				}
			} else {
				switch ptype {
				case packetTypeInitial:
					k = tc.keysInitial.r
				case packetTypeHandshake:
					k = tc.keysHandshake.r
				}
			}
			if !k.isSet() {
				t.Fatalf("reading %v packet with no read key", ptype)
			}
			var pnumMax packetNumber // TODO: Track packet numbers.
			p, n := parseLongHeaderPacket(buf, k, pnumMax)
			if n < 0 {
				t.Fatalf("packet parse error")
			}
			frames, err := parseTestFrames(t, p.payload)
			if err != nil {
				t.Fatal(err)
			}
			var token []byte
			if ptype == packetTypeInitial && len(p.extra) > 0 {
				token = p.extra
			}
			d.packets = append(d.packets, &testPacket{
				ptype:     p.ptype,
				header:    buf[0],
				version:   p.version,
				num:       p.num,
				dstConnID: p.dstConnID,
				srcConnID: p.srcConnID,
				token:     token,
				frames:    frames,
			})
			buf = buf[n:]
		case packetType1RTT:
			if tc == nil || !tc.rkeyAppData.hdr.isSet() {
				t.Fatalf("reading 1-RTT packet with no read key")
			}
			var pnumMax packetNumber // TODO: Track packet numbers.
			pnumOff := 1 + len(tc.peerConnID)
			// Try unprotecting the packet with the first maxTestKeyPhases keys.
			// Each attempt works on a fresh copy, since unprotection
			// modifies its input.
			var phase int
			var pnum packetNumber
			var hdr []byte
			var pay []byte
			var err error
			for phase = 0; phase < maxTestKeyPhases; phase++ {
				b := append([]byte{}, buf...)
				hdr, pay, pnum, err = tc.rkeyAppData.hdr.unprotect(b, pnumOff, pnumMax)
				if err != nil {
					t.Fatalf("1-RTT packet header parse error")
				}
				k := tc.rkeyAppData.pkt[phase]
				pay, err = k.unprotect(hdr, pay, pnum)
				if err == nil {
					break
				}
			}
			if err != nil {
				t.Fatalf("1-RTT packet payload parse error")
			}
			frames, err := parseTestFrames(t, pay)
			if err != nil {
				t.Fatal(err)
			}
			d.packets = append(d.packets, &testPacket{
				ptype:       packetType1RTT,
				header:      hdr[0],
				num:         pnum,
				dstConnID:   hdr[1:][:len(tc.peerConnID)],
				keyPhaseBit: hdr[0]&keyPhaseBit != 0,
				keyNumber:   phase,
				frames:      frames,
			})
			// A 1-RTT packet extends to the end of the datagram.
			buf = buf[len(buf):]
		default:
			t.Fatalf("unhandled packet type %v", ptype)
		}
	}
	// This is rather hackish: If the last frame in the last packet
	// in the datagram is PADDING, then remove it and record
	// the padded size in the testDatagram.paddedSize.
	//
	// This makes it easier to write a test that expects a datagram
	// padded to 1200 bytes.
	if len(d.packets) > 0 && len(d.packets[len(d.packets)-1].frames) > 0 {
		p := d.packets[len(d.packets)-1]
		f := p.frames[len(p.frames)-1]
		if _, ok := f.(debugFramePadding); ok {
			p.frames = p.frames[:len(p.frames)-1]
			d.paddedSize = size
		}
	}
	return d
}

// parseTestFrames decodes every frame in payload.
// It returns an error if any frame fails to parse.
func parseTestFrames(t *testing.T, payload []byte) ([]debugFrame, error) {
	t.Helper()
	var frames []debugFrame
	for len(payload) > 0 {
		f, n := parseDebugFrame(payload)
		if n < 0 {
			return nil, errors.New("error parsing frames")
		}
		frames = append(frames, f)
		payload = payload[n:]
	}
	return frames, nil
}

// spaceForPacketType returns the packet number space used by ptype.
// It panics for 0-RTT (not yet supported), Retry, and unknown packet types.
func spaceForPacketType(ptype packetType) numberSpace {
	switch ptype {
	case packetTypeInitial:
		return initialSpace
	case packetType0RTT:
		panic("TODO: packetType0RTT")
	case packetTypeHandshake:
		return handshakeSpace
	case packetTypeRetry:
		panic("retry packets have no number space")
	case packetType1RTT:
		return appDataSpace
	}
	panic("unknown packet type")
}

// testConnHooks implements connTestHooks.
type testConnHooks testConn

// init disables key updates on the Conn, records the Conn's Initial keys
// with read and write swapped (the test acts as the peer), and, for a
// server Conn, adds it to the test endpoint's accept queue.
func (tc *testConnHooks) init() {
	tc.conn.keysAppData.updateAfter = maxPacketNumber // disable key updates
	tc.keysInitial.r = tc.conn.keysInitial.w
	tc.keysInitial.w = tc.conn.keysInitial.r
	if tc.conn.side == serverSide {
		tc.endpoint.acceptQueue = append(tc.endpoint.acceptQueue, (*testConn)(tc))
	}
}

// handleTLSEvent processes TLS events generated by
// the connection under test's tls.QUICConn.
//
// We maintain a second tls.QUICConn representing the peer,
// and feed the TLS handshake data into it.
//
// We stash TLS handshake data from both sides in the testConn,
// where it can be used by tests.
//
// We snoop packet protection keys out of the tls.QUICConns,
// and verify that both sides of the connection are getting
// matching keys.
func (tc *testConnHooks) handleTLSEvent(e tls.QUICEvent) {
	// checkKey records the first secret seen for e's number space in secrets,
	// and reports an error if a later secret for that space differs.
	// typ names the direction, from the test's point of view, for the error message.
	checkKey := func(typ string, secrets *[numberSpaceCount]keySecret, e tls.QUICEvent) {
		var space numberSpace
		switch e.Level {
		case tls.QUICEncryptionLevelHandshake:
			space = handshakeSpace
		case tls.QUICEncryptionLevelApplication:
			space = appDataSpace
		default:
			tc.t.Errorf("unexpected encryption level %v", e.Level)
			return
		}
		if secrets[space].secret == nil {
			secrets[space].suite = e.Suite
			secrets[space].secret = append([]byte{}, e.Data...)
		} else if secrets[space].suite != e.Suite || !bytes.Equal(secrets[space].secret, e.Data) {
			tc.t.Errorf("%v key mismatch for level %v", typ, e.Level)
		}
	}
	// setAppDataKey derives the 1-RTT header key and one packet key per
	// key phase, applying a key update to the secret between phases.
	setAppDataKey := func(suite uint16, secret []byte, k *test1RTTKeys) {
		k.hdr.init(suite, secret)
		for i := range k.pkt {
			k.pkt[i].init(suite, secret)
			secret = updateSecret(suite, secret)
		}
	}

	// Events from the Conn under test. The Conn's read keys are the
	// test's write keys, and vice versa.
	switch e.Kind {
	case tls.QUICSetReadSecret:
		checkKey("write", &tc.wsecrets, e)
		switch e.Level {
		case tls.QUICEncryptionLevelHandshake:
			tc.keysHandshake.w.init(e.Suite, e.Data)
		case tls.QUICEncryptionLevelApplication:
			setAppDataKey(e.Suite, e.Data, &tc.wkeyAppData)
		}
	case tls.QUICSetWriteSecret:
		checkKey("read", &tc.rsecrets, e)
		switch e.Level {
		case tls.QUICEncryptionLevelHandshake:
			tc.keysHandshake.r.init(e.Suite, e.Data)
		case tls.QUICEncryptionLevelApplication:
			setAppDataKey(e.Suite, e.Data, &tc.rkeyAppData)
		}
	case tls.QUICWriteData:
		tc.cryptoDataOut[e.Level] = append(tc.cryptoDataOut[e.Level], e.Data...)
		// NOTE(review): HandleData's error is discarded; presumably so tests
		// can drive handshakes the peer rejects — confirm before changing.
		tc.peerTLSConn.HandleData(e.Level, e.Data)
	}

	// Drain events from the peer. The peer's keys match the test's
	// keys directly: the peer reads what the test reads.
	for {
		e := tc.peerTLSConn.NextEvent()
		switch e.Kind {
		case tls.QUICNoEvent:
			return
		case tls.QUICSetReadSecret:
			checkKey("read", &tc.rsecrets, e)
			switch e.Level {
			case tls.QUICEncryptionLevelHandshake:
				tc.keysHandshake.r.init(e.Suite, e.Data)
			case tls.QUICEncryptionLevelApplication:
				setAppDataKey(e.Suite, e.Data, &tc.rkeyAppData)
			}
		case tls.QUICSetWriteSecret:
			checkKey("write", &tc.wsecrets, e)
			switch e.Level {
			case tls.QUICEncryptionLevelHandshake:
				tc.keysHandshake.w.init(e.Suite, e.Data)
			case tls.QUICEncryptionLevelApplication:
				setAppDataKey(e.Suite, e.Data, &tc.wkeyAppData)
			}
		case tls.QUICWriteData:
			tc.cryptoDataIn[e.Level] = append(tc.cryptoDataIn[e.Level], e.Data...)
		case tls.QUICTransportParameters:
			p, err := unmarshalTransportParams(e.Data)
			if err != nil {
				tc.t.Logf("sent unparseable transport parameters %x %v", e.Data, err)
			} else {
				tc.sentTransportParameters = &p
			}
		}
	}
}

// nextMessage is called by the Conn's event loop to request its next event.
//
// It fires the Conn's timer when the test clock has reached it, returns any
// queued message, and runs pending async operations. When there is nothing
// left to do, it signals a waiting testConn.wait (by closing idlec) and
// blocks until the next message arrives.
func (tc *testConnHooks) nextMessage(msgc chan any, timer time.Time) (now time.Time, m any) {
	tc.timer = timer
	for {
		if !timer.IsZero() && !timer.After(tc.endpoint.now) {
			if timer.Equal(tc.timerLastFired) {
				// If the connection timer fires at time T, the Conn should take some
				// action to advance the timer into the future. If the Conn reschedules
				// the timer for the same time, it isn't making progress and we have a bug.
				tc.t.Errorf("connection timer spinning; now=%v timer=%v", tc.endpoint.now, timer)
			} else {
				tc.timerLastFired = timer
				return tc.endpoint.now, timerEvent{}
			}
		}
		select {
		case m := <-msgc:
			return tc.endpoint.now, m
		default:
		}
		if !tc.wakeAsync() {
			break
		}
	}
	// If the message queue is empty, then the conn is idle.
	if tc.idlec != nil {
		idlec := tc.idlec
		tc.idlec = nil
		close(idlec)
	}
	m = <-msgc
	return tc.endpoint.now, m
}

// newConnID returns the deterministic test connection ID for seq.
func (tc *testConnHooks) newConnID(seq int64) ([]byte, error) {
	return testLocalConnID(seq), nil
}

// timeNow returns the test endpoint's current (simulated) time.
func (tc *testConnHooks) timeNow() time.Time {
	return tc.endpoint.now
}

// testLocalConnID returns the connection ID with a given sequence number
// used by a Conn under test.
func testLocalConnID(seq int64) []byte {
	cid := make([]byte, connIDLen)
	copy(cid, []byte{0xc0, 0xff, 0xee})
	cid[len(cid)-1] = byte(seq)
	return cid
}

// testPeerConnID returns the connection ID with a given sequence number
// used by the fake peer of a Conn under test.
func testPeerConnID(seq int64) []byte {
	// Use a different length than we choose for our own conn ids,
	// to help catch any bad assumptions.
	return []byte{0xbe, 0xee, 0xff, byte(seq)}
}

// testPeerStatelessResetToken returns the stateless reset token
// with a given sequence number used by the fake peer.
func testPeerStatelessResetToken(seq int64) statelessResetToken {
	return statelessResetToken{
		0xee, 0xee, 0xee, 0xee, 0xee, 0xee, 0xee, 0xee,
		0xee, 0xee, 0xee, 0xee, 0xee, 0xee, 0xee, byte(seq),
	}
}

// canceledContext returns a canceled Context.
//
// Functions which take a context prefer progress over cancelation.
// For example, a read with a canceled context will return data if any is available.
// Tests use canceled contexts to perform non-blocking operations.
func canceledContext() context.Context {
	ctx, cancel := context.WithCancel(context.Background())
	cancel()
	return ctx
}