github.com/tumi8/quic-go@v0.37.4-tum/client_test.go (about) 1 package quic 2 3 import ( 4 "context" 5 "crypto/tls" 6 "errors" 7 "net" 8 "time" 9 10 mocklogging "github.com/tumi8/quic-go/noninternal/mocks/logging" 11 "github.com/tumi8/quic-go/noninternal/protocol" 12 "github.com/tumi8/quic-go/noninternal/utils" 13 "github.com/tumi8/quic-go/logging" 14 15 "github.com/golang/mock/gomock" 16 17 . "github.com/onsi/ginkgo/v2" 18 . "github.com/onsi/gomega" 19 ) 20 21 type nullMultiplexer struct{} 22 23 func (n nullMultiplexer) AddConn(indexableConn) {} 24 func (n nullMultiplexer) RemoveConn(indexableConn) error { return nil } 25 26 var _ = Describe("Client", func() { 27 var ( 28 cl *client 29 packetConn *MockSendConn 30 connID protocol.ConnectionID 31 origMultiplexer multiplexer 32 tlsConf *tls.Config 33 tracer *mocklogging.MockConnectionTracer 34 config *Config 35 36 originalClientConnConstructor func( 37 conn sendConn, 38 runner connRunner, 39 destConnID protocol.ConnectionID, 40 srcConnID protocol.ConnectionID, 41 connIDGenerator ConnectionIDGenerator, 42 conf *Config, 43 tlsConf *tls.Config, 44 initialPacketNumber protocol.PacketNumber, 45 enable0RTT bool, 46 hasNegotiatedVersion bool, 47 tracer logging.ConnectionTracer, 48 tracingID uint64, 49 logger utils.Logger, 50 v protocol.VersionNumber, 51 ) quicConn 52 ) 53 54 BeforeEach(func() { 55 tlsConf = &tls.Config{NextProtos: []string{"proto1"}} 56 connID = protocol.ParseConnectionID([]byte{0, 0, 0, 0, 0, 0, 0x13, 0x37}) 57 originalClientConnConstructor = newClientConnection 58 tracer = mocklogging.NewMockConnectionTracer(mockCtrl) 59 config = &Config{ 60 Tracer: func(ctx context.Context, perspective logging.Perspective, id ConnectionID) logging.ConnectionTracer { 61 return tracer 62 }, 63 Versions: []protocol.VersionNumber{protocol.Version1}, 64 } 65 Eventually(areConnsRunning).Should(BeFalse()) 66 packetConn = NewMockSendConn(mockCtrl) 67 packetConn.EXPECT().LocalAddr().Return(&net.UDPAddr{}).AnyTimes() 68 packetConn.EXPECT().RemoteAddr().Return(&net.UDPAddr{}).AnyTimes() 69 cl = &client{ 70 srcConnID: connID, 71 destConnID: connID, 72 version: protocol.Version1, 73 sendConn: packetConn, 74 tracer: tracer, 75 logger: utils.DefaultLogger, 76 } 77 getMultiplexer() // make the sync.Once execute 78 // replace the clientMuxer. getMultiplexer will now return the nullMultiplexer 79 origMultiplexer = connMuxer 80 connMuxer = &nullMultiplexer{} 81 }) 82 83 AfterEach(func() { 84 connMuxer = origMultiplexer 85 newClientConnection = originalClientConnConstructor 86 }) 87 88 AfterEach(func() { 89 if s, ok := cl.conn.(*connection); ok { 90 s.shutdown() 91 } 92 Eventually(areConnsRunning).Should(BeFalse()) 93 }) 94 95 Context("Dialing", func() { 96 var origGenerateConnectionIDForInitial func() (protocol.ConnectionID, error) 97 98 BeforeEach(func() { 99 origGenerateConnectionIDForInitial = generateConnectionIDForInitial 100 generateConnectionIDForInitial = func() (protocol.ConnectionID, error) { 101 return connID, nil 102 } 103 }) 104 105 AfterEach(func() { 106 generateConnectionIDForInitial = origGenerateConnectionIDForInitial 107 }) 108 109 It("returns after the handshake is complete", func() { 110 manager := NewMockPacketHandlerManager(mockCtrl) 111 manager.EXPECT().Add(gomock.Any(), gomock.Any()) 112 113 run := make(chan struct{}) 114 newClientConnection = func( 115 _ sendConn, 116 _ connRunner, 117 _ protocol.ConnectionID, 118 _ protocol.ConnectionID, 119 _ ConnectionIDGenerator, 120 _ *Config, 121 _ *tls.Config, 122 _ protocol.PacketNumber, 123 enable0RTT bool, 124 _ bool, 125 _ logging.ConnectionTracer, 126 _ uint64, 127 _ utils.Logger, 128 _ protocol.VersionNumber, 129 ) quicConn { 130 Expect(enable0RTT).To(BeFalse()) 131 conn := NewMockQUICConn(mockCtrl) 132 conn.EXPECT().run().Do(func() { close(run) }) 133 c := make(chan struct{}) 134 close(c) 135 conn.EXPECT().HandshakeComplete().Return(c) 136 return conn 137 } 138 cl, err := newClient(packetConn, &protocol.DefaultConnectionIDGenerator{}, populateConfig(config), tlsConf, nil, false) 139 Expect(err).ToNot(HaveOccurred()) 140 cl.packetHandlers = manager 141 Expect(cl).ToNot(BeNil()) 142 Expect(cl.dial(context.Background())).To(Succeed()) 143 Eventually(run).Should(BeClosed()) 144 }) 145 146 It("returns early connections", func() { 147 manager := NewMockPacketHandlerManager(mockCtrl) 148 manager.EXPECT().Add(gomock.Any(), gomock.Any()) 149 readyChan := make(chan struct{}) 150 done := make(chan struct{}) 151 newClientConnection = func( 152 _ sendConn, 153 runner connRunner, 154 _ protocol.ConnectionID, 155 _ protocol.ConnectionID, 156 _ ConnectionIDGenerator, 157 _ *Config, 158 _ *tls.Config, 159 _ protocol.PacketNumber, 160 enable0RTT bool, 161 _ bool, 162 _ logging.ConnectionTracer, 163 _ uint64, 164 _ utils.Logger, 165 _ protocol.VersionNumber, 166 ) quicConn { 167 Expect(enable0RTT).To(BeTrue()) 168 conn := NewMockQUICConn(mockCtrl) 169 conn.EXPECT().run().Do(func() { close(done) }) 170 conn.EXPECT().HandshakeComplete().Return(make(chan struct{})) 171 conn.EXPECT().earlyConnReady().Return(readyChan) 172 return conn 173 } 174 175 cl, err := newClient(packetConn, &protocol.DefaultConnectionIDGenerator{}, populateConfig(config), tlsConf, nil, true) 176 Expect(err).ToNot(HaveOccurred()) 177 cl.packetHandlers = manager 178 Expect(cl).ToNot(BeNil()) 179 Expect(cl.dial(context.Background())).To(Succeed()) 180 Eventually(done).Should(BeClosed()) 181 }) 182 183 It("returns an error that occurs while waiting for the handshake to complete", func() { 184 manager := NewMockPacketHandlerManager(mockCtrl) 185 manager.EXPECT().Add(gomock.Any(), gomock.Any()) 186 187 testErr := errors.New("early handshake error") 188 newClientConnection = func( 189 _ sendConn, 190 _ connRunner, 191 _ protocol.ConnectionID, 192 _ protocol.ConnectionID, 193 _ ConnectionIDGenerator, 194 _ *Config, 195 _ *tls.Config, 196 _ protocol.PacketNumber, 197 _ bool, 198 _ bool, 199 _ logging.ConnectionTracer, 200 _ uint64, 201 _ utils.Logger, 202 _ protocol.VersionNumber, 203 ) quicConn { 204 conn := NewMockQUICConn(mockCtrl) 205 conn.EXPECT().run().Return(testErr) 206 conn.EXPECT().HandshakeComplete().Return(make(chan struct{})) 207 conn.EXPECT().earlyConnReady().Return(make(chan struct{})) 208 return conn 209 } 210 var closed bool 211 cl, err := newClient(packetConn, &protocol.DefaultConnectionIDGenerator{}, populateConfig(config), tlsConf, func() { closed = true }, true) 212 Expect(err).ToNot(HaveOccurred()) 213 cl.packetHandlers = manager 214 Expect(cl).ToNot(BeNil()) 215 Expect(cl.dial(context.Background())).To(MatchError(testErr)) 216 Expect(closed).To(BeTrue()) 217 }) 218 219 Context("quic.Config", func() { 220 It("setups with the right values", func() { 221 tokenStore := NewLRUTokenStore(10, 4) 222 config := &Config{ 223 HandshakeIdleTimeout: 1337 * time.Minute, 224 MaxIdleTimeout: 42 * time.Hour, 225 MaxIncomingStreams: 1234, 226 MaxIncomingUniStreams: 4321, 227 TokenStore: tokenStore, 228 EnableDatagrams: true, 229 } 230 c := populateConfig(config) 231 Expect(c.HandshakeIdleTimeout).To(Equal(1337 * time.Minute)) 232 Expect(c.MaxIdleTimeout).To(Equal(42 * time.Hour)) 233 Expect(c.MaxIncomingStreams).To(BeEquivalentTo(1234)) 234 Expect(c.MaxIncomingUniStreams).To(BeEquivalentTo(4321)) 235 Expect(c.TokenStore).To(Equal(tokenStore)) 236 Expect(c.EnableDatagrams).To(BeTrue()) 237 }) 238 239 It("disables bidirectional streams", func() { 240 config := &Config{ 241 MaxIncomingStreams: -1, 242 MaxIncomingUniStreams: 4321, 243 } 244 c := populateConfig(config) 245 Expect(c.MaxIncomingStreams).To(BeZero()) 246 Expect(c.MaxIncomingUniStreams).To(BeEquivalentTo(4321)) 247 }) 248 249 It("disables unidirectional streams", func() { 250 config := &Config{ 251 MaxIncomingStreams: 1234, 252 MaxIncomingUniStreams: -1, 253 } 254 c := populateConfig(config) 255 Expect(c.MaxIncomingStreams).To(BeEquivalentTo(1234)) 256 Expect(c.MaxIncomingUniStreams).To(BeZero()) 257 }) 258 259 It("fills in default values if options are not set in the Config", func() { 260 c := populateConfig(&Config{}) 261 Expect(c.Versions).To(Equal(protocol.SupportedVersions)) 262 Expect(c.HandshakeIdleTimeout).To(Equal(protocol.DefaultHandshakeIdleTimeout)) 263 Expect(c.MaxIdleTimeout).To(Equal(protocol.DefaultIdleTimeout)) 264 }) 265 }) 266 267 It("creates new connections with the right parameters", func() { 268 config := &Config{Versions: []protocol.VersionNumber{protocol.Version1}} 269 c := make(chan struct{}) 270 var version protocol.VersionNumber 271 var conf *Config 272 done := make(chan struct{}) 273 newClientConnection = func( 274 connP sendConn, 275 _ connRunner, 276 _ protocol.ConnectionID, 277 _ protocol.ConnectionID, 278 _ ConnectionIDGenerator, 279 configP *Config, 280 _ *tls.Config, 281 _ protocol.PacketNumber, 282 _ bool, 283 _ bool, 284 _ logging.ConnectionTracer, 285 _ uint64, 286 _ utils.Logger, 287 versionP protocol.VersionNumber, 288 ) quicConn { 289 version = versionP 290 conf = configP 291 close(c) 292 // TODO: check connection IDs? 293 conn := NewMockQUICConn(mockCtrl) 294 conn.EXPECT().run() 295 conn.EXPECT().HandshakeComplete().Return(make(chan struct{})) 296 conn.EXPECT().destroy(gomock.Any()).MaxTimes(1) 297 close(done) 298 return conn 299 } 300 packetConn := NewMockPacketConn(mockCtrl) 301 packetConn.EXPECT().ReadFrom(gomock.Any()).DoAndReturn(func([]byte) (int, net.Addr, error) { 302 <-done 303 return 0, nil, errors.New("closed") 304 }) 305 packetConn.EXPECT().LocalAddr() 306 packetConn.EXPECT().SetReadDeadline(gomock.Any()).AnyTimes() 307 _, err := Dial(context.Background(), packetConn, &net.UDPAddr{}, tlsConf, config) 308 Expect(err).ToNot(HaveOccurred()) 309 Eventually(c).Should(BeClosed()) 310 Expect(version).To(Equal(config.Versions[0])) 311 Expect(conf.Versions).To(Equal(config.Versions)) 312 }) 313 314 It("creates a new connections after version negotiation", func() { 315 var counter int 316 newClientConnection = func( 317 _ sendConn, 318 runner connRunner, 319 _ protocol.ConnectionID, 320 connID protocol.ConnectionID, 321 _ ConnectionIDGenerator, 322 configP *Config, 323 _ *tls.Config, 324 pn protocol.PacketNumber, 325 _ bool, 326 hasNegotiatedVersion bool, 327 _ logging.ConnectionTracer, 328 _ uint64, 329 _ utils.Logger, 330 versionP protocol.VersionNumber, 331 ) quicConn { 332 conn := NewMockQUICConn(mockCtrl) 333 conn.EXPECT().HandshakeComplete().Return(make(chan struct{})) 334 if counter == 0 { 335 Expect(pn).To(BeZero()) 336 Expect(hasNegotiatedVersion).To(BeFalse()) 337 conn.EXPECT().run().DoAndReturn(func() error { 338 runner.Remove(connID) 339 return &errCloseForRecreating{ 340 nextPacketNumber: 109, 341 nextVersion: 789, 342 } 343 }) 344 } else { 345 Expect(pn).To(Equal(protocol.PacketNumber(109))) 346 Expect(hasNegotiatedVersion).To(BeTrue()) 347 conn.EXPECT().run() 348 conn.EXPECT().destroy(gomock.Any()) 349 } 350 counter++ 351 return conn 352 } 353 354 config := &Config{Tracer: config.Tracer, Versions: []protocol.VersionNumber{protocol.Version1}} 355 tracer.EXPECT().StartedConnection(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) 356 _, err := DialAddr(context.Background(), "localhost:7890", tlsConf, config) 357 Expect(err).ToNot(HaveOccurred()) 358 Expect(counter).To(Equal(2)) 359 }) 360 }) 361 })