github.com/tumi8/quic-go@v0.37.4-tum/integrationtests/versionnegotiation/handshake_test.go (about) 1 package versionnegotiation 2 3 import ( 4 "context" 5 "crypto/tls" 6 "fmt" 7 "net" 8 9 "github.com/tumi8/quic-go" 10 "github.com/tumi8/quic-go/integrationtests/tools/israce" 11 "github.com/tumi8/quic-go/noninternal/protocol" 12 "github.com/tumi8/quic-go/logging" 13 14 . "github.com/onsi/ginkgo/v2" 15 . "github.com/onsi/gomega" 16 ) 17 18 type versioner interface { 19 GetVersion() protocol.VersionNumber 20 } 21 22 type versionNegotiationTracer struct { 23 logging.NullConnectionTracer 24 25 loggedVersions bool 26 receivedVersionNegotiation bool 27 chosen logging.VersionNumber 28 clientVersions, serverVersions []logging.VersionNumber 29 } 30 31 var _ logging.ConnectionTracer = &versionNegotiationTracer{} 32 33 func (t *versionNegotiationTracer) NegotiatedVersion(chosen logging.VersionNumber, clientVersions, serverVersions []logging.VersionNumber) { 34 if t.loggedVersions { 35 Fail("only expected one call to NegotiatedVersions") 36 } 37 t.loggedVersions = true 38 t.chosen = chosen 39 t.clientVersions = clientVersions 40 t.serverVersions = serverVersions 41 } 42 43 func (t *versionNegotiationTracer) ReceivedVersionNegotiationPacket(dest, src logging.ArbitraryLenConnectionID, _ []logging.VersionNumber) { 44 t.receivedVersionNegotiation = true 45 } 46 47 var _ = Describe("Handshake tests", func() { 48 startServer := func(tlsConf *tls.Config, conf *quic.Config) (*quic.Listener, func()) { 49 server, err := quic.ListenAddr("localhost:0", tlsConf, conf) 50 Expect(err).ToNot(HaveOccurred()) 51 52 acceptStopped := make(chan struct{}) 53 go func() { 54 defer GinkgoRecover() 55 defer close(acceptStopped) 56 for { 57 if _, err := server.Accept(context.Background()); err != nil { 58 return 59 } 60 } 61 }() 62 63 return server, func() { 64 server.Close() 65 <-acceptStopped 66 } 67 } 68 69 var supportedVersions []protocol.VersionNumber 70 71 BeforeEach(func() { 72 supportedVersions = append([]quic.VersionNumber{}, protocol.SupportedVersions...) 73 protocol.SupportedVersions = append(protocol.SupportedVersions, []protocol.VersionNumber{7, 8, 9, 10}...) 74 }) 75 76 AfterEach(func() { 77 protocol.SupportedVersions = supportedVersions 78 }) 79 80 if !israce.Enabled { 81 It("when the server supports more versions than the client", func() { 82 expectedVersion := protocol.SupportedVersions[0] 83 // the server doesn't support the highest supported version, which is the first one the client will try 84 // but it supports a bunch of versions that the client doesn't speak 85 serverConfig := &quic.Config{} 86 serverConfig.Versions = []protocol.VersionNumber{7, 8, protocol.SupportedVersions[0], 9} 87 serverTracer := &versionNegotiationTracer{} 88 serverConfig.Tracer = func(context.Context, logging.Perspective, quic.ConnectionID) logging.ConnectionTracer { 89 return serverTracer 90 } 91 server, cl := startServer(getTLSConfig(), serverConfig) 92 defer cl() 93 clientTracer := &versionNegotiationTracer{} 94 conn, err := quic.DialAddr( 95 context.Background(), 96 fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port), 97 getTLSClientConfig(), 98 maybeAddQLOGTracer(&quic.Config{Tracer: func(ctx context.Context, perspective logging.Perspective, id quic.ConnectionID) logging.ConnectionTracer { 99 return clientTracer 100 }}), 101 ) 102 Expect(err).ToNot(HaveOccurred()) 103 Expect(conn.(versioner).GetVersion()).To(Equal(expectedVersion)) 104 Expect(conn.CloseWithError(0, "")).To(Succeed()) 105 Expect(clientTracer.chosen).To(Equal(expectedVersion)) 106 Expect(clientTracer.receivedVersionNegotiation).To(BeFalse()) 107 Expect(clientTracer.clientVersions).To(Equal(protocol.SupportedVersions)) 108 Expect(clientTracer.serverVersions).To(BeEmpty()) 109 Expect(serverTracer.chosen).To(Equal(expectedVersion)) 110 Expect(serverTracer.serverVersions).To(Equal(serverConfig.Versions)) 111 Expect(serverTracer.clientVersions).To(BeEmpty()) 112 }) 113 114 It("when the client supports more versions than the server supports", func() { 115 expectedVersion := protocol.SupportedVersions[0] 116 // the server doesn't support the highest supported version, which is the first one the client will try 117 // but it supports a bunch of versions that the client doesn't speak 118 serverTracer := &versionNegotiationTracer{} 119 serverConfig := &quic.Config{} 120 serverConfig.Versions = supportedVersions 121 serverConfig.Tracer = func(context.Context, logging.Perspective, quic.ConnectionID) logging.ConnectionTracer { 122 return serverTracer 123 } 124 server, cl := startServer(getTLSConfig(), serverConfig) 125 defer cl() 126 clientVersions := []protocol.VersionNumber{7, 8, 9, protocol.SupportedVersions[0], 10} 127 clientTracer := &versionNegotiationTracer{} 128 conn, err := quic.DialAddr( 129 context.Background(), 130 fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port), 131 getTLSClientConfig(), 132 maybeAddQLOGTracer(&quic.Config{ 133 Versions: clientVersions, 134 Tracer: func(context.Context, logging.Perspective, quic.ConnectionID) logging.ConnectionTracer { 135 return clientTracer 136 }, 137 }), 138 ) 139 Expect(err).ToNot(HaveOccurred()) 140 Expect(conn.(versioner).GetVersion()).To(Equal(protocol.SupportedVersions[0])) 141 Expect(conn.CloseWithError(0, "")).To(Succeed()) 142 Expect(clientTracer.chosen).To(Equal(expectedVersion)) 143 Expect(clientTracer.receivedVersionNegotiation).To(BeTrue()) 144 Expect(clientTracer.clientVersions).To(Equal(clientVersions)) 145 Expect(clientTracer.serverVersions).To(ContainElements(supportedVersions)) // may contain greased versions 146 Expect(serverTracer.chosen).To(Equal(expectedVersion)) 147 Expect(serverTracer.serverVersions).To(Equal(serverConfig.Versions)) 148 Expect(serverTracer.clientVersions).To(BeEmpty()) 149 }) 150 } 151 })