github.com/tumi8/quic-go@v0.37.4-tum/integrationtests/self/key_update_test.go (about) 1 package self_test 2 3 import ( 4 "context" 5 "fmt" 6 "io" 7 "net" 8 9 "github.com/tumi8/quic-go" 10 "github.com/tumi8/quic-go/noninternal/handshake" 11 "github.com/tumi8/quic-go/noninternal/protocol" 12 "github.com/tumi8/quic-go/logging" 13 14 . "github.com/onsi/ginkgo/v2" 15 . "github.com/onsi/gomega" 16 ) 17 18 var ( 19 sentHeaders []*logging.ShortHeader 20 receivedHeaders []*logging.ShortHeader 21 ) 22 23 func countKeyPhases() (sent, received int) { 24 lastKeyPhase := protocol.KeyPhaseOne 25 for _, hdr := range sentHeaders { 26 if hdr.KeyPhase != lastKeyPhase { 27 sent++ 28 lastKeyPhase = hdr.KeyPhase 29 } 30 } 31 lastKeyPhase = protocol.KeyPhaseOne 32 for _, hdr := range receivedHeaders { 33 if hdr.KeyPhase != lastKeyPhase { 34 received++ 35 lastKeyPhase = hdr.KeyPhase 36 } 37 } 38 return 39 } 40 41 type keyUpdateConnTracer struct { 42 logging.NullConnectionTracer 43 } 44 45 func (t *keyUpdateConnTracer) SentShortHeaderPacket(hdr *logging.ShortHeader, _ logging.ByteCount, _ *logging.AckFrame, _ []logging.Frame) { 46 sentHeaders = append(sentHeaders, hdr) 47 } 48 49 func (t *keyUpdateConnTracer) ReceivedShortHeaderPacket(hdr *logging.ShortHeader, size logging.ByteCount, frames []logging.Frame) { 50 receivedHeaders = append(receivedHeaders, hdr) 51 } 52 53 var _ = Describe("Key Update tests", func() { 54 It("downloads a large file", func() { 55 origKeyUpdateInterval := handshake.KeyUpdateInterval 56 defer func() { handshake.KeyUpdateInterval = origKeyUpdateInterval }() 57 handshake.KeyUpdateInterval = 1 // update keys as frequently as possible 58 59 server, err := quic.ListenAddr("localhost:0", getTLSConfig(), nil) 60 Expect(err).ToNot(HaveOccurred()) 61 defer server.Close() 62 63 go func() { 64 defer GinkgoRecover() 65 conn, err := server.Accept(context.Background()) 66 Expect(err).ToNot(HaveOccurred()) 67 str, err := conn.OpenUniStream() 68 Expect(err).ToNot(HaveOccurred()) 69 defer str.Close() 70 _, err = str.Write(PRDataLong) 71 Expect(err).ToNot(HaveOccurred()) 72 }() 73 74 conn, err := quic.DialAddr( 75 context.Background(), 76 fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port), 77 getTLSClientConfig(), 78 getQuicConfig(&quic.Config{Tracer: func(context.Context, logging.Perspective, quic.ConnectionID) logging.ConnectionTracer { 79 return &keyUpdateConnTracer{} 80 }}), 81 ) 82 Expect(err).ToNot(HaveOccurred()) 83 str, err := conn.AcceptUniStream(context.Background()) 84 Expect(err).ToNot(HaveOccurred()) 85 data, err := io.ReadAll(str) 86 Expect(err).ToNot(HaveOccurred()) 87 Expect(data).To(Equal(PRDataLong)) 88 Expect(conn.CloseWithError(0, "")).To(Succeed()) 89 90 keyPhasesSent, keyPhasesReceived := countKeyPhases() 91 fmt.Fprintf(GinkgoWriter, "Used %d key phases on outgoing and %d key phases on incoming packets.\n", keyPhasesSent, keyPhasesReceived) 92 Expect(keyPhasesReceived).To(BeNumerically(">", 10)) 93 Expect(keyPhasesReceived).To(BeNumerically("~", keyPhasesSent, 2)) 94 }) 95 })