github.com/apernet/quic-go@v0.43.1-0.20240515053213-5e9e635fd9f0/integrationtests/versionnegotiation/handshake_test.go (about) 1 package versionnegotiation 2 3 import ( 4 "context" 5 "crypto/tls" 6 "errors" 7 "fmt" 8 "net" 9 "time" 10 11 "github.com/apernet/quic-go" 12 "github.com/apernet/quic-go/integrationtests/tools/israce" 13 "github.com/apernet/quic-go/internal/protocol" 14 "github.com/apernet/quic-go/logging" 15 16 . "github.com/onsi/ginkgo/v2" 17 . "github.com/onsi/gomega" 18 ) 19 20 type versioner interface { 21 GetVersion() protocol.Version 22 } 23 24 type result struct { 25 loggedVersions bool 26 receivedVersionNegotiation bool 27 chosen logging.VersionNumber 28 clientVersions, serverVersions []logging.VersionNumber 29 } 30 31 func newVersionNegotiationTracer() (*result, *logging.ConnectionTracer) { 32 r := &result{} 33 return r, &logging.ConnectionTracer{ 34 NegotiatedVersion: func(chosen logging.VersionNumber, clientVersions, serverVersions []logging.VersionNumber) { 35 if r.loggedVersions { 36 Fail("only expected one call to NegotiatedVersions") 37 } 38 r.loggedVersions = true 39 r.chosen = chosen 40 r.clientVersions = clientVersions 41 r.serverVersions = serverVersions 42 }, 43 ReceivedVersionNegotiationPacket: func(dest, src logging.ArbitraryLenConnectionID, _ []logging.VersionNumber) { 44 r.receivedVersionNegotiation = true 45 }, 46 } 47 } 48 49 var _ = Describe("Handshake tests", func() { 50 startServer := func(tlsConf *tls.Config, conf *quic.Config) (*quic.Listener, func()) { 51 server, err := quic.ListenAddr("localhost:0", tlsConf, conf) 52 Expect(err).ToNot(HaveOccurred()) 53 54 acceptStopped := make(chan struct{}) 55 go func() { 56 defer GinkgoRecover() 57 defer close(acceptStopped) 58 for { 59 if _, err := server.Accept(context.Background()); err != nil { 60 return 61 } 62 } 63 }() 64 65 return server, func() { 66 server.Close() 67 <-acceptStopped 68 } 69 } 70 71 var supportedVersions []protocol.Version 72 73 BeforeEach(func() { 74 supportedVersions = append([]quic.Version{}, protocol.SupportedVersions...) 75 protocol.SupportedVersions = append(protocol.SupportedVersions, []protocol.Version{7, 8, 9, 10}...) 76 }) 77 78 AfterEach(func() { 79 protocol.SupportedVersions = supportedVersions 80 }) 81 82 if !israce.Enabled { 83 It("when the server supports more versions than the client", func() { 84 expectedVersion := protocol.SupportedVersions[0] 85 // the server doesn't support the highest supported version, which is the first one the client will try 86 // but it supports a bunch of versions that the client doesn't speak 87 serverConfig := &quic.Config{} 88 serverConfig.Versions = []protocol.Version{7, 8, protocol.SupportedVersions[0], 9} 89 serverResult, serverTracer := newVersionNegotiationTracer() 90 serverConfig.Tracer = func(context.Context, logging.Perspective, quic.ConnectionID) *logging.ConnectionTracer { 91 return serverTracer 92 } 93 server, cl := startServer(getTLSConfig(), serverConfig) 94 defer cl() 95 clientResult, clientTracer := newVersionNegotiationTracer() 96 conn, err := quic.DialAddr( 97 context.Background(), 98 fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port), 99 getTLSClientConfig(), 100 maybeAddQLOGTracer(&quic.Config{Tracer: func(ctx context.Context, perspective logging.Perspective, id quic.ConnectionID) *logging.ConnectionTracer { 101 return clientTracer 102 }}), 103 ) 104 Expect(err).ToNot(HaveOccurred()) 105 Expect(conn.(versioner).GetVersion()).To(Equal(expectedVersion)) 106 Expect(conn.CloseWithError(0, "")).To(Succeed()) 107 Expect(clientResult.chosen).To(Equal(expectedVersion)) 108 Expect(clientResult.receivedVersionNegotiation).To(BeFalse()) 109 Expect(clientResult.clientVersions).To(Equal(protocol.SupportedVersions)) 110 Expect(clientResult.serverVersions).To(BeEmpty()) 111 Expect(serverResult.chosen).To(Equal(expectedVersion)) 112 Expect(serverResult.serverVersions).To(Equal(serverConfig.Versions)) 113 Expect(serverResult.clientVersions).To(BeEmpty()) 114 }) 115 116 It("when the client supports more versions than the server supports", func() { 117 expectedVersion := protocol.SupportedVersions[0] 118 // The server doesn't support the highest supported version, which is the first one the client will try, 119 // but it supports a bunch of versions that the client doesn't speak 120 serverResult, serverTracer := newVersionNegotiationTracer() 121 serverConfig := &quic.Config{} 122 serverConfig.Versions = supportedVersions 123 serverConfig.Tracer = func(context.Context, logging.Perspective, quic.ConnectionID) *logging.ConnectionTracer { 124 return serverTracer 125 } 126 server, cl := startServer(getTLSConfig(), serverConfig) 127 defer cl() 128 clientVersions := []protocol.Version{7, 8, 9, protocol.SupportedVersions[0], 10} 129 clientResult, clientTracer := newVersionNegotiationTracer() 130 conn, err := quic.DialAddr( 131 context.Background(), 132 fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port), 133 getTLSClientConfig(), 134 maybeAddQLOGTracer(&quic.Config{ 135 Versions: clientVersions, 136 Tracer: func(context.Context, logging.Perspective, quic.ConnectionID) *logging.ConnectionTracer { 137 return clientTracer 138 }, 139 }), 140 ) 141 Expect(err).ToNot(HaveOccurred()) 142 Expect(conn.(versioner).GetVersion()).To(Equal(protocol.SupportedVersions[0])) 143 Expect(conn.CloseWithError(0, "")).To(Succeed()) 144 Expect(clientResult.chosen).To(Equal(expectedVersion)) 145 Expect(clientResult.receivedVersionNegotiation).To(BeTrue()) 146 Expect(clientResult.clientVersions).To(Equal(clientVersions)) 147 Expect(clientResult.serverVersions).To(ContainElements(supportedVersions)) // may contain greased versions 148 Expect(serverResult.chosen).To(Equal(expectedVersion)) 149 Expect(serverResult.serverVersions).To(Equal(serverConfig.Versions)) 150 Expect(serverResult.clientVersions).To(BeEmpty()) 151 }) 152 153 It("fails if the server disables version negotiation", func() { 154 // The server doesn't support the highest supported version, which is the first one the client will try, 155 // but it supports a bunch of versions that the client doesn't speak 156 _, serverTracer := newVersionNegotiationTracer() 157 serverConfig := &quic.Config{} 158 serverConfig.Versions = supportedVersions 159 serverConfig.Tracer = func(context.Context, logging.Perspective, quic.ConnectionID) *logging.ConnectionTracer { 160 return serverTracer 161 } 162 conn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1)}) 163 Expect(err).ToNot(HaveOccurred()) 164 tr := &quic.Transport{ 165 Conn: conn, 166 DisableVersionNegotiationPackets: true, 167 } 168 ln, err := tr.Listen(getTLSConfig(), serverConfig) 169 Expect(err).ToNot(HaveOccurred()) 170 defer ln.Close() 171 172 clientVersions := []protocol.Version{7, 8, 9, protocol.SupportedVersions[0], 10} 173 clientResult, clientTracer := newVersionNegotiationTracer() 174 _, err = quic.DialAddr( 175 context.Background(), 176 fmt.Sprintf("localhost:%d", conn.LocalAddr().(*net.UDPAddr).Port), 177 getTLSClientConfig(), 178 maybeAddQLOGTracer(&quic.Config{ 179 Versions: clientVersions, 180 Tracer: func(context.Context, logging.Perspective, quic.ConnectionID) *logging.ConnectionTracer { 181 return clientTracer 182 }, 183 HandshakeIdleTimeout: 100 * time.Millisecond, 184 }), 185 ) 186 Expect(err).To(HaveOccurred()) 187 var nerr net.Error 188 Expect(errors.As(err, &nerr)).To(BeTrue()) 189 Expect(nerr.Timeout()).To(BeTrue()) 190 Expect(clientResult.receivedVersionNegotiation).To(BeFalse()) 191 }) 192 } 193 })