github.com/daeuniverse/quic-go@v0.0.0-20240413031024-943f218e0810/integrationtests/self/key_update_test.go (about) 1 package self_test 2 3 import ( 4 "context" 5 "fmt" 6 "io" 7 "net" 8 9 "github.com/daeuniverse/quic-go" 10 "github.com/daeuniverse/quic-go/internal/handshake" 11 "github.com/daeuniverse/quic-go/internal/protocol" 12 "github.com/daeuniverse/quic-go/logging" 13 14 . "github.com/onsi/ginkgo/v2" 15 . "github.com/onsi/gomega" 16 ) 17 18 var ( 19 sentHeaders []*logging.ShortHeader 20 receivedHeaders []*logging.ShortHeader 21 ) 22 23 func countKeyPhases() (sent, received int) { 24 lastKeyPhase := protocol.KeyPhaseOne 25 for _, hdr := range sentHeaders { 26 if hdr.KeyPhase != lastKeyPhase { 27 sent++ 28 lastKeyPhase = hdr.KeyPhase 29 } 30 } 31 lastKeyPhase = protocol.KeyPhaseOne 32 for _, hdr := range receivedHeaders { 33 if hdr.KeyPhase != lastKeyPhase { 34 received++ 35 lastKeyPhase = hdr.KeyPhase 36 } 37 } 38 return 39 } 40 41 var keyUpdateConnTracer = &logging.ConnectionTracer{ 42 SentShortHeaderPacket: func(hdr *logging.ShortHeader, _ logging.ByteCount, _ logging.ECN, _ *logging.AckFrame, _ []logging.Frame) { 43 sentHeaders = append(sentHeaders, hdr) 44 }, 45 ReceivedShortHeaderPacket: func(hdr *logging.ShortHeader, _ logging.ByteCount, _ logging.ECN, _ []logging.Frame) { 46 receivedHeaders = append(receivedHeaders, hdr) 47 }, 48 } 49 50 var _ = Describe("Key Update tests", func() { 51 It("downloads a large file", func() { 52 origKeyUpdateInterval := handshake.KeyUpdateInterval 53 defer func() { handshake.KeyUpdateInterval = origKeyUpdateInterval }() 54 handshake.KeyUpdateInterval = 1 // update keys as frequently as possible 55 56 server, err := quic.ListenAddr("localhost:0", getTLSConfig(), nil) 57 Expect(err).ToNot(HaveOccurred()) 58 defer server.Close() 59 60 go func() { 61 defer GinkgoRecover() 62 conn, err := server.Accept(context.Background()) 63 Expect(err).ToNot(HaveOccurred()) 64 str, err := conn.OpenUniStream() 65 Expect(err).ToNot(HaveOccurred()) 66 defer str.Close() 67 _, err = str.Write(PRDataLong) 68 Expect(err).ToNot(HaveOccurred()) 69 }() 70 71 conn, err := quic.DialAddr( 72 context.Background(), 73 fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port), 74 getTLSClientConfig(), 75 getQuicConfig(&quic.Config{Tracer: func(context.Context, logging.Perspective, quic.ConnectionID) *logging.ConnectionTracer { 76 return keyUpdateConnTracer 77 }}), 78 ) 79 Expect(err).ToNot(HaveOccurred()) 80 str, err := conn.AcceptUniStream(context.Background()) 81 Expect(err).ToNot(HaveOccurred()) 82 data, err := io.ReadAll(str) 83 Expect(err).ToNot(HaveOccurred()) 84 Expect(data).To(Equal(PRDataLong)) 85 Expect(conn.CloseWithError(0, "")).To(Succeed()) 86 87 keyPhasesSent, keyPhasesReceived := countKeyPhases() 88 fmt.Fprintf(GinkgoWriter, "Used %d key phases on outgoing and %d key phases on incoming packets.\n", keyPhasesSent, keyPhasesReceived) 89 Expect(keyPhasesReceived).To(BeNumerically(">", 10)) 90 Expect(keyPhasesReceived).To(BeNumerically("~", keyPhasesSent, 2)) 91 }) 92 })