github.com/MerlinKodo/quic-go@v0.39.2/client_test.go (about) 1 package quic 2 3 import ( 4 "context" 5 "crypto/tls" 6 "errors" 7 "net" 8 "time" 9 10 mocklogging "github.com/MerlinKodo/quic-go/internal/mocks/logging" 11 "github.com/MerlinKodo/quic-go/internal/protocol" 12 "github.com/MerlinKodo/quic-go/internal/utils" 13 "github.com/MerlinKodo/quic-go/logging" 14 15 . "github.com/onsi/ginkgo/v2" 16 . "github.com/onsi/gomega" 17 "go.uber.org/mock/gomock" 18 ) 19 20 type nullMultiplexer struct{} 21 22 func (n nullMultiplexer) AddConn(indexableConn) {} 23 func (n nullMultiplexer) RemoveConn(indexableConn) error { return nil } 24 25 var _ = Describe("Client", func() { 26 var ( 27 cl *client 28 packetConn *MockSendConn 29 connID protocol.ConnectionID 30 origMultiplexer multiplexer 31 tlsConf *tls.Config 32 tracer *mocklogging.MockConnectionTracer 33 config *Config 34 35 originalClientConnConstructor func( 36 conn sendConn, 37 runner connRunner, 38 destConnID protocol.ConnectionID, 39 srcConnID protocol.ConnectionID, 40 connIDGenerator ConnectionIDGenerator, 41 conf *Config, 42 tlsConf *tls.Config, 43 initialPacketNumber protocol.PacketNumber, 44 enable0RTT bool, 45 hasNegotiatedVersion bool, 46 tracer *logging.ConnectionTracer, 47 tracingID uint64, 48 logger utils.Logger, 49 v protocol.VersionNumber, 50 ) quicConn 51 ) 52 53 BeforeEach(func() { 54 tlsConf = &tls.Config{NextProtos: []string{"proto1"}} 55 connID = protocol.ParseConnectionID([]byte{0, 0, 0, 0, 0, 0, 0x13, 0x37}) 56 originalClientConnConstructor = newClientConnection 57 var tr *logging.ConnectionTracer 58 tr, tracer = mocklogging.NewMockConnectionTracer(mockCtrl) 59 config = &Config{ 60 Tracer: func(ctx context.Context, perspective logging.Perspective, id ConnectionID) *logging.ConnectionTracer { 61 return tr 62 }, 63 Versions: []protocol.VersionNumber{protocol.Version1}, 64 } 65 Eventually(areConnsRunning).Should(BeFalse()) 66 packetConn = NewMockSendConn(mockCtrl) 67 packetConn.EXPECT().LocalAddr().Return(&net.UDPAddr{}).AnyTimes() 68 packetConn.EXPECT().RemoteAddr().Return(&net.UDPAddr{}).AnyTimes() 69 cl = &client{ 70 srcConnID: connID, 71 destConnID: connID, 72 version: protocol.Version1, 73 sendConn: packetConn, 74 tracer: tr, 75 logger: utils.DefaultLogger, 76 } 77 getMultiplexer() // make the sync.Once execute 78 // replace the clientMuxer. getMultiplexer will now return the nullMultiplexer 79 origMultiplexer = connMuxer 80 connMuxer = &nullMultiplexer{} 81 }) 82 83 AfterEach(func() { 84 connMuxer = origMultiplexer 85 newClientConnection = originalClientConnConstructor 86 }) 87 88 AfterEach(func() { 89 if s, ok := cl.conn.(*connection); ok { 90 s.shutdown() 91 } 92 Eventually(areConnsRunning).Should(BeFalse()) 93 }) 94 95 Context("Dialing", func() { 96 var origGenerateConnectionIDForInitial func() (protocol.ConnectionID, error) 97 98 BeforeEach(func() { 99 origGenerateConnectionIDForInitial = generateConnectionIDForInitial 100 generateConnectionIDForInitial = func() (protocol.ConnectionID, error) { 101 return connID, nil 102 } 103 }) 104 105 AfterEach(func() { 106 generateConnectionIDForInitial = origGenerateConnectionIDForInitial 107 }) 108 109 It("returns after the handshake is complete", func() { 110 manager := NewMockPacketHandlerManager(mockCtrl) 111 manager.EXPECT().Add(gomock.Any(), gomock.Any()) 112 113 run := make(chan struct{}) 114 newClientConnection = func( 115 _ sendConn, 116 _ connRunner, 117 _ protocol.ConnectionID, 118 _ protocol.ConnectionID, 119 _ ConnectionIDGenerator, 120 _ *Config, 121 _ *tls.Config, 122 _ protocol.PacketNumber, 123 enable0RTT bool, 124 _ bool, 125 _ *logging.ConnectionTracer, 126 _ uint64, 127 _ utils.Logger, 128 _ protocol.VersionNumber, 129 ) quicConn { 130 Expect(enable0RTT).To(BeFalse()) 131 conn := NewMockQUICConn(mockCtrl) 132 conn.EXPECT().run().Do(func() { close(run) }) 133 c := make(chan struct{}) 134 close(c) 135 conn.EXPECT().HandshakeComplete().Return(c) 136 return conn 137 } 138 cl, err := newClient(packetConn, &protocol.DefaultConnectionIDGenerator{}, populateConfig(config), tlsConf, nil, false) 139 Expect(err).ToNot(HaveOccurred()) 140 cl.packetHandlers = manager 141 Expect(cl).ToNot(BeNil()) 142 Expect(cl.dial(context.Background())).To(Succeed()) 143 Eventually(run).Should(BeClosed()) 144 }) 145 146 It("returns early connections", func() { 147 manager := NewMockPacketHandlerManager(mockCtrl) 148 manager.EXPECT().Add(gomock.Any(), gomock.Any()) 149 readyChan := make(chan struct{}) 150 done := make(chan struct{}) 151 newClientConnection = func( 152 _ sendConn, 153 runner connRunner, 154 _ protocol.ConnectionID, 155 _ protocol.ConnectionID, 156 _ ConnectionIDGenerator, 157 _ *Config, 158 _ *tls.Config, 159 _ protocol.PacketNumber, 160 enable0RTT bool, 161 _ bool, 162 _ *logging.ConnectionTracer, 163 _ uint64, 164 _ utils.Logger, 165 _ protocol.VersionNumber, 166 ) quicConn { 167 Expect(enable0RTT).To(BeTrue()) 168 conn := NewMockQUICConn(mockCtrl) 169 conn.EXPECT().run().Do(func() { close(done) }) 170 conn.EXPECT().HandshakeComplete().Return(make(chan struct{})) 171 conn.EXPECT().earlyConnReady().Return(readyChan) 172 return conn 173 } 174 175 cl, err := newClient(packetConn, &protocol.DefaultConnectionIDGenerator{}, populateConfig(config), tlsConf, nil, true) 176 Expect(err).ToNot(HaveOccurred()) 177 cl.packetHandlers = manager 178 Expect(cl).ToNot(BeNil()) 179 Expect(cl.dial(context.Background())).To(Succeed()) 180 Eventually(done).Should(BeClosed()) 181 }) 182 183 It("returns an error that occurs while waiting for the handshake to complete", func() { 184 manager := NewMockPacketHandlerManager(mockCtrl) 185 manager.EXPECT().Add(gomock.Any(), gomock.Any()) 186 187 testErr := errors.New("early handshake error") 188 newClientConnection = func( 189 _ sendConn, 190 _ connRunner, 191 _ protocol.ConnectionID, 192 _ protocol.ConnectionID, 193 _ ConnectionIDGenerator, 194 _ *Config, 195 _ *tls.Config, 196 _ protocol.PacketNumber, 197 _ bool, 198 _ bool, 199 _ *logging.ConnectionTracer, 200 _ uint64, 201 _ utils.Logger, 202 _ protocol.VersionNumber, 203 ) quicConn { 204 conn := NewMockQUICConn(mockCtrl) 205 conn.EXPECT().run().Return(testErr) 206 conn.EXPECT().HandshakeComplete().Return(make(chan struct{})) 207 conn.EXPECT().earlyConnReady().Return(make(chan struct{})) 208 return conn 209 } 210 var closed bool 211 cl, err := newClient(packetConn, &protocol.DefaultConnectionIDGenerator{}, populateConfig(config), tlsConf, func() { closed = true }, true) 212 Expect(err).ToNot(HaveOccurred()) 213 cl.packetHandlers = manager 214 Expect(cl).ToNot(BeNil()) 215 Expect(cl.dial(context.Background())).To(MatchError(testErr)) 216 Expect(closed).To(BeTrue()) 217 }) 218 219 Context("quic.Config", func() { 220 It("setups with the right values", func() { 221 tokenStore := NewLRUTokenStore(10, 4) 222 config := &Config{ 223 HandshakeIdleTimeout: 1337 * time.Minute, 224 MaxIdleTimeout: 42 * time.Hour, 225 MaxIncomingStreams: 1234, 226 MaxIncomingUniStreams: 4321, 227 TokenStore: tokenStore, 228 EnableDatagrams: true, 229 } 230 c := populateConfig(config) 231 Expect(c.HandshakeIdleTimeout).To(Equal(1337 * time.Minute)) 232 Expect(c.MaxIdleTimeout).To(Equal(42 * time.Hour)) 233 Expect(c.MaxIncomingStreams).To(BeEquivalentTo(1234)) 234 Expect(c.MaxIncomingUniStreams).To(BeEquivalentTo(4321)) 235 Expect(c.TokenStore).To(Equal(tokenStore)) 236 Expect(c.EnableDatagrams).To(BeTrue()) 237 }) 238 239 It("disables bidirectional streams", func() { 240 config := &Config{ 241 MaxIncomingStreams: -1, 242 MaxIncomingUniStreams: 4321, 243 } 244 c := populateConfig(config) 245 Expect(c.MaxIncomingStreams).To(BeZero()) 246 Expect(c.MaxIncomingUniStreams).To(BeEquivalentTo(4321)) 247 }) 248 249 It("disables unidirectional streams", func() { 250 config := &Config{ 251 MaxIncomingStreams: 1234, 252 MaxIncomingUniStreams: -1, 253 } 254 c := populateConfig(config) 255 Expect(c.MaxIncomingStreams).To(BeEquivalentTo(1234)) 256 Expect(c.MaxIncomingUniStreams).To(BeZero()) 257 }) 258 259 It("fills in default values if options are not set in the Config", func() { 260 c := populateConfig(&Config{}) 261 Expect(c.Versions).To(Equal(protocol.SupportedVersions)) 262 Expect(c.HandshakeIdleTimeout).To(Equal(protocol.DefaultHandshakeIdleTimeout)) 263 Expect(c.MaxIdleTimeout).To(Equal(protocol.DefaultIdleTimeout)) 264 }) 265 }) 266 267 It("creates new connections with the right parameters", func() { 268 config := &Config{Versions: []protocol.VersionNumber{protocol.Version1}} 269 c := make(chan struct{}) 270 var version protocol.VersionNumber 271 var conf *Config 272 done := make(chan struct{}) 273 newClientConnection = func( 274 connP sendConn, 275 _ connRunner, 276 _ protocol.ConnectionID, 277 _ protocol.ConnectionID, 278 _ ConnectionIDGenerator, 279 configP *Config, 280 _ *tls.Config, 281 _ protocol.PacketNumber, 282 _ bool, 283 _ bool, 284 _ *logging.ConnectionTracer, 285 _ uint64, 286 _ utils.Logger, 287 versionP protocol.VersionNumber, 288 ) quicConn { 289 version = versionP 290 conf = configP 291 close(c) 292 // TODO: check connection IDs? 293 conn := NewMockQUICConn(mockCtrl) 294 conn.EXPECT().run() 295 conn.EXPECT().HandshakeComplete().Return(make(chan struct{})) 296 conn.EXPECT().destroy(gomock.Any()).MaxTimes(1) 297 close(done) 298 return conn 299 } 300 packetConn := NewMockPacketConn(mockCtrl) 301 packetConn.EXPECT().ReadFrom(gomock.Any()).DoAndReturn(func([]byte) (int, net.Addr, error) { 302 <-done 303 return 0, nil, errors.New("closed") 304 }) 305 packetConn.EXPECT().LocalAddr() 306 packetConn.EXPECT().SetReadDeadline(gomock.Any()).AnyTimes() 307 _, err := Dial(context.Background(), packetConn, &net.UDPAddr{}, tlsConf, config) 308 Expect(err).ToNot(HaveOccurred()) 309 Eventually(c).Should(BeClosed()) 310 Expect(version).To(Equal(config.Versions[0])) 311 Expect(conf.Versions).To(Equal(config.Versions)) 312 }) 313 314 It("creates a new connections after version negotiation", func() { 315 var counter int 316 newClientConnection = func( 317 _ sendConn, 318 runner connRunner, 319 _ protocol.ConnectionID, 320 connID protocol.ConnectionID, 321 _ ConnectionIDGenerator, 322 configP *Config, 323 _ *tls.Config, 324 pn protocol.PacketNumber, 325 _ bool, 326 hasNegotiatedVersion bool, 327 _ *logging.ConnectionTracer, 328 _ uint64, 329 _ utils.Logger, 330 versionP protocol.VersionNumber, 331 ) quicConn { 332 conn := NewMockQUICConn(mockCtrl) 333 conn.EXPECT().HandshakeComplete().Return(make(chan struct{})) 334 if counter == 0 { 335 Expect(pn).To(BeZero()) 336 Expect(hasNegotiatedVersion).To(BeFalse()) 337 conn.EXPECT().run().DoAndReturn(func() error { 338 runner.Remove(connID) 339 return &errCloseForRecreating{ 340 nextPacketNumber: 109, 341 nextVersion: 789, 342 } 343 }) 344 } else { 345 Expect(pn).To(Equal(protocol.PacketNumber(109))) 346 Expect(hasNegotiatedVersion).To(BeTrue()) 347 conn.EXPECT().run() 348 conn.EXPECT().destroy(gomock.Any()) 349 } 350 counter++ 351 return conn 352 } 353 354 config := &Config{Tracer: config.Tracer, Versions: []protocol.VersionNumber{protocol.Version1}} 355 tracer.EXPECT().StartedConnection(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) 356 _, err := DialAddr(context.Background(), "localhost:7890", tlsConf, config) 357 Expect(err).ToNot(HaveOccurred()) 358 Expect(counter).To(Equal(2)) 359 }) 360 }) 361 })