github.com/daeuniverse/quic-go@v0.0.0-20240413031024-943f218e0810/server_test.go (about) 1 package quic 2 3 import ( 4 "context" 5 "crypto/rand" 6 "crypto/tls" 7 "errors" 8 "net" 9 "sync" 10 "sync/atomic" 11 "time" 12 13 "golang.org/x/time/rate" 14 15 "github.com/daeuniverse/quic-go/internal/handshake" 16 mocklogging "github.com/daeuniverse/quic-go/internal/mocks/logging" 17 "github.com/daeuniverse/quic-go/internal/protocol" 18 "github.com/daeuniverse/quic-go/internal/qerr" 19 "github.com/daeuniverse/quic-go/internal/testdata" 20 "github.com/daeuniverse/quic-go/internal/utils" 21 "github.com/daeuniverse/quic-go/internal/wire" 22 "github.com/daeuniverse/quic-go/logging" 23 24 . "github.com/onsi/ginkgo/v2" 25 . "github.com/onsi/gomega" 26 "go.uber.org/mock/gomock" 27 ) 28 29 var _ = Describe("Server", func() { 30 var ( 31 conn *MockPacketConn 32 tlsConf *tls.Config 33 ) 34 35 getPacket := func(hdr *wire.Header, p []byte) receivedPacket { 36 buf := getPacketBuffer() 37 hdr.Length = 4 + protocol.ByteCount(len(p)) + 16 38 var err error 39 buf.Data, err = (&wire.ExtendedHeader{ 40 Header: *hdr, 41 PacketNumber: 0x42, 42 PacketNumberLen: protocol.PacketNumberLen4, 43 }).Append(buf.Data, protocol.Version1) 44 Expect(err).ToNot(HaveOccurred()) 45 n := len(buf.Data) 46 buf.Data = append(buf.Data, p...) 
47 data := buf.Data 48 sealer, _ := handshake.NewInitialAEAD(hdr.DestConnectionID, protocol.PerspectiveClient, hdr.Version) 49 _ = sealer.Seal(data[n:n], data[n:], 0x42, data[:n]) 50 data = data[:len(data)+16] 51 sealer.EncryptHeader(data[n:n+16], &data[0], data[n-4:n]) 52 return receivedPacket{ 53 rcvTime: time.Now(), 54 remoteAddr: &net.UDPAddr{IP: net.IPv4(4, 5, 6, 7), Port: 456}, 55 data: data, 56 buffer: buf, 57 } 58 } 59 60 getInitial := func(destConnID protocol.ConnectionID) receivedPacket { 61 senderAddr := &net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 42} 62 hdr := &wire.Header{ 63 Type: protocol.PacketTypeInitial, 64 SrcConnectionID: protocol.ParseConnectionID([]byte{5, 4, 3, 2, 1}), 65 DestConnectionID: destConnID, 66 Version: protocol.Version1, 67 } 68 p := getPacket(hdr, make([]byte, protocol.MinInitialPacketSize)) 69 p.buffer = getPacketBuffer() 70 p.remoteAddr = senderAddr 71 return p 72 } 73 74 getInitialWithRandomDestConnID := func() receivedPacket { 75 b := make([]byte, 10) 76 _, err := rand.Read(b) 77 Expect(err).ToNot(HaveOccurred()) 78 79 return getInitial(protocol.ParseConnectionID(b)) 80 } 81 82 parseHeader := func(data []byte) *wire.Header { 83 hdr, _, _, err := wire.ParsePacket(data) 84 Expect(err).ToNot(HaveOccurred()) 85 return hdr 86 } 87 88 checkConnectionCloseError := func(b []byte, origHdr *wire.Header, errorCode qerr.TransportErrorCode) { 89 replyHdr := parseHeader(b) 90 Expect(replyHdr.Type).To(Equal(protocol.PacketTypeInitial)) 91 Expect(replyHdr.SrcConnectionID).To(Equal(origHdr.DestConnectionID)) 92 Expect(replyHdr.DestConnectionID).To(Equal(origHdr.SrcConnectionID)) 93 _, opener := handshake.NewInitialAEAD(origHdr.DestConnectionID, protocol.PerspectiveClient, replyHdr.Version) 94 extHdr, err := unpackLongHeader(opener, replyHdr, b, origHdr.Version) 95 Expect(err).ToNot(HaveOccurred()) 96 data, err := opener.Open(nil, b[extHdr.ParsedLen():], extHdr.PacketNumber, b[:extHdr.ParsedLen()]) 97 Expect(err).ToNot(HaveOccurred()) 98 _, f, 
err := wire.NewFrameParser(false).ParseNext(data, protocol.EncryptionInitial, origHdr.Version) 99 Expect(err).ToNot(HaveOccurred()) 100 Expect(f).To(BeAssignableToTypeOf(&wire.ConnectionCloseFrame{})) 101 ccf := f.(*wire.ConnectionCloseFrame) 102 Expect(ccf.IsApplicationError).To(BeFalse()) 103 Expect(ccf.ErrorCode).To(BeEquivalentTo(errorCode)) 104 Expect(ccf.ReasonPhrase).To(BeEmpty()) 105 } 106 107 BeforeEach(func() { 108 conn = NewMockPacketConn(mockCtrl) 109 conn.EXPECT().LocalAddr().Return(&net.UDPAddr{}).AnyTimes() 110 wait := make(chan struct{}) 111 conn.EXPECT().ReadFrom(gomock.Any()).DoAndReturn(func(_ []byte) (int, net.Addr, error) { 112 <-wait 113 return 0, nil, errors.New("done") 114 }).MaxTimes(1) 115 conn.EXPECT().SetReadDeadline(gomock.Any()).Do(func(time.Time) error { 116 close(wait) 117 conn.EXPECT().SetReadDeadline(time.Time{}) 118 return nil 119 }).MaxTimes(1) 120 tlsConf = testdata.GetTLSConfig() 121 tlsConf.NextProtos = []string{"proto1"} 122 }) 123 124 It("errors when no tls.Config is given", func() { 125 _, err := ListenAddr("localhost:0", nil, nil) 126 Expect(err).To(HaveOccurred()) 127 Expect(err.Error()).To(ContainSubstring("quic: tls.Config not set")) 128 }) 129 130 It("errors when the Config contains an invalid version", func() { 131 version := protocol.Version(0x1234) 132 _, err := Listen(nil, tlsConf, &Config{Versions: []protocol.Version{version}}) 133 Expect(err).To(MatchError("invalid QUIC version: 0x1234")) 134 }) 135 136 It("fills in default values if options are not set in the Config", func() { 137 ln, err := Listen(conn, tlsConf, &Config{}) 138 Expect(err).ToNot(HaveOccurred()) 139 server := ln.baseServer 140 Expect(server.config.Versions).To(Equal(protocol.SupportedVersions)) 141 Expect(server.config.HandshakeIdleTimeout).To(Equal(protocol.DefaultHandshakeIdleTimeout)) 142 Expect(server.config.MaxIdleTimeout).To(Equal(protocol.DefaultIdleTimeout)) 143 Expect(server.config.KeepAlivePeriod).To(BeZero()) 144 // stop the listener 
145 Expect(ln.Close()).To(Succeed()) 146 }) 147 148 It("setups with the right values", func() { 149 supportedVersions := []protocol.Version{protocol.Version1} 150 config := Config{ 151 Versions: supportedVersions, 152 HandshakeIdleTimeout: 1337 * time.Hour, 153 MaxIdleTimeout: 42 * time.Minute, 154 KeepAlivePeriod: 5 * time.Second, 155 } 156 ln, err := Listen(conn, tlsConf, &config) 157 Expect(err).ToNot(HaveOccurred()) 158 server := ln.baseServer 159 Expect(server.connHandler).ToNot(BeNil()) 160 Expect(server.config.Versions).To(Equal(supportedVersions)) 161 Expect(server.config.HandshakeIdleTimeout).To(Equal(1337 * time.Hour)) 162 Expect(server.config.MaxIdleTimeout).To(Equal(42 * time.Minute)) 163 Expect(server.config.KeepAlivePeriod).To(Equal(5 * time.Second)) 164 // stop the listener 165 Expect(ln.Close()).To(Succeed()) 166 }) 167 168 It("listens on a given address", func() { 169 addr := "127.0.0.1:13579" 170 ln, err := ListenAddr(addr, tlsConf, &Config{}) 171 Expect(err).ToNot(HaveOccurred()) 172 Expect(ln.Addr().String()).To(Equal(addr)) 173 // stop the listener 174 Expect(ln.Close()).To(Succeed()) 175 }) 176 177 It("errors if given an invalid address", func() { 178 addr := "127.0.0.1" 179 _, err := ListenAddr(addr, tlsConf, &Config{}) 180 Expect(err).To(BeAssignableToTypeOf(&net.AddrError{})) 181 }) 182 183 It("errors if given an invalid address", func() { 184 addr := "1.1.1.1:1111" 185 _, err := ListenAddr(addr, tlsConf, &Config{}) 186 Expect(err).To(BeAssignableToTypeOf(&net.OpError{})) 187 }) 188 189 Context("server accepting connections that completed the handshake", func() { 190 var ( 191 tr *Transport 192 serv *baseServer 193 phm *MockPacketHandlerManager 194 tracer *mocklogging.MockTracer 195 ) 196 197 BeforeEach(func() { 198 var t *logging.Tracer 199 t, tracer = mocklogging.NewMockTracer(mockCtrl) 200 tr = &Transport{Conn: conn, Tracer: t} 201 ln, err := tr.Listen(tlsConf, nil) 202 Expect(err).ToNot(HaveOccurred()) 203 serv = ln.baseServer 204 phm = 
NewMockPacketHandlerManager(mockCtrl) 205 serv.connHandler = phm 206 }) 207 208 AfterEach(func() { 209 tracer.EXPECT().Close() 210 tr.Close() 211 }) 212 213 Context("handling packets", func() { 214 It("drops Initial packets with a too short connection ID", func() { 215 p := getPacket(&wire.Header{ 216 Type: protocol.PacketTypeInitial, 217 DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4}), 218 Version: serv.config.Versions[0], 219 }, nil) 220 tracer.EXPECT().DroppedPacket(p.remoteAddr, logging.PacketTypeInitial, p.Size(), logging.PacketDropUnexpectedPacket) 221 serv.handlePacket(p) 222 // make sure there are no Write calls on the packet conn 223 time.Sleep(50 * time.Millisecond) 224 }) 225 226 It("drops too small Initial", func() { 227 p := getPacket(&wire.Header{ 228 Type: protocol.PacketTypeInitial, 229 DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}), 230 Version: serv.config.Versions[0], 231 }, make([]byte, protocol.MinInitialPacketSize-100)) 232 tracer.EXPECT().DroppedPacket(p.remoteAddr, logging.PacketTypeInitial, p.Size(), logging.PacketDropUnexpectedPacket) 233 serv.handlePacket(p) 234 // make sure there are no Write calls on the packet conn 235 time.Sleep(50 * time.Millisecond) 236 }) 237 238 It("drops non-Initial packets", func() { 239 p := getPacket(&wire.Header{ 240 Type: protocol.PacketTypeHandshake, 241 Version: serv.config.Versions[0], 242 }, []byte("invalid")) 243 tracer.EXPECT().DroppedPacket(p.remoteAddr, logging.PacketTypeHandshake, p.Size(), logging.PacketDropUnexpectedPacket) 244 serv.handlePacket(p) 245 // make sure there are no Write calls on the packet conn 246 time.Sleep(50 * time.Millisecond) 247 }) 248 249 It("passes packets to existing connections", func() { 250 connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}) 251 p := getPacket(&wire.Header{ 252 Type: protocol.PacketTypeInitial, 253 DestConnectionID: connID, 254 Version: serv.config.Versions[0], 255 }, make([]byte, 
protocol.MinInitialPacketSize)) 256 conn := NewMockPacketHandler(mockCtrl) 257 phm.EXPECT().Get(connID).Return(conn, true) 258 handled := make(chan struct{}) 259 conn.EXPECT().handlePacket(p).Do(func(receivedPacket) { close(handled) }) 260 serv.handlePacket(p) 261 Eventually(handled).Should(BeClosed()) 262 }) 263 264 It("creates a connection when the token is accepted", func() { 265 serv.verifySourceAddress = func(net.Addr) bool { return true } 266 raddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337} 267 retryToken, err := serv.tokenGenerator.NewRetryToken( 268 raddr, 269 protocol.ParseConnectionID([]byte{0xde, 0xad, 0xc0, 0xde}), 270 protocol.ParseConnectionID([]byte{0xde, 0xca, 0xfb, 0xad}), 271 ) 272 Expect(err).ToNot(HaveOccurred()) 273 connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}) 274 hdr := &wire.Header{ 275 Type: protocol.PacketTypeInitial, 276 SrcConnectionID: protocol.ParseConnectionID([]byte{5, 4, 3, 2, 1}), 277 DestConnectionID: connID, 278 Version: protocol.Version1, 279 Token: retryToken, 280 } 281 p := getPacket(hdr, make([]byte, protocol.MinInitialPacketSize)) 282 p.remoteAddr = raddr 283 run := make(chan struct{}) 284 var token protocol.StatelessResetToken 285 rand.Read(token[:]) 286 287 var newConnID protocol.ConnectionID 288 conn := NewMockQUICConn(mockCtrl) 289 serv.newConn = func( 290 _ sendConn, 291 _ connRunner, 292 origDestConnID protocol.ConnectionID, 293 retrySrcConnID *protocol.ConnectionID, 294 clientDestConnID protocol.ConnectionID, 295 destConnID protocol.ConnectionID, 296 srcConnID protocol.ConnectionID, 297 _ ConnectionIDGenerator, 298 tokenP protocol.StatelessResetToken, 299 _ *Config, 300 _ *tls.Config, 301 _ *handshake.TokenGenerator, 302 _ bool, 303 _ *logging.ConnectionTracer, 304 _ uint64, 305 _ utils.Logger, 306 _ protocol.Version, 307 ) quicConn { 308 Expect(origDestConnID).To(Equal(protocol.ParseConnectionID([]byte{0xde, 0xad, 0xc0, 0xde}))) 309 
Expect(*retrySrcConnID).To(Equal(protocol.ParseConnectionID([]byte{0xde, 0xca, 0xfb, 0xad}))) 310 Expect(clientDestConnID).To(Equal(hdr.DestConnectionID)) 311 Expect(destConnID).To(Equal(hdr.SrcConnectionID)) 312 // make sure we're using a server-generated connection ID 313 Expect(srcConnID).ToNot(Equal(hdr.DestConnectionID)) 314 Expect(srcConnID).ToNot(Equal(hdr.SrcConnectionID)) 315 newConnID = srcConnID 316 Expect(tokenP).To(Equal(token)) 317 conn.EXPECT().handlePacket(p) 318 conn.EXPECT().run().Do(func() error { close(run); return nil }) 319 conn.EXPECT().Context().Return(context.Background()) 320 conn.EXPECT().HandshakeComplete().Return(make(chan struct{})) 321 return conn 322 } 323 phm.EXPECT().Get(connID) 324 phm.EXPECT().GetStatelessResetToken(gomock.Any()).Return(token) 325 phm.EXPECT().AddWithConnID(connID, gomock.Any(), gomock.Any()).DoAndReturn(func(_, cid protocol.ConnectionID, h packetHandler) bool { 326 Expect(cid).To(Equal(newConnID)) 327 return true 328 }) 329 330 done := make(chan struct{}) 331 go func() { 332 defer GinkgoRecover() 333 serv.handlePacket(p) 334 // the Handshake packet is written by the connection. 335 // Make sure there are no Write calls on the packet conn. 
336 time.Sleep(50 * time.Millisecond) 337 close(done) 338 }() 339 // make sure we're using a server-generated connection ID 340 Eventually(run).Should(BeClosed()) 341 Eventually(done).Should(BeClosed()) 342 // shutdown 343 conn.EXPECT().closeWithTransportError(gomock.Any()) 344 }) 345 346 It("sends a Version Negotiation Packet for unsupported versions", func() { 347 srcConnID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5}) 348 destConnID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6}) 349 packet := getPacket(&wire.Header{ 350 Type: protocol.PacketTypeHandshake, 351 SrcConnectionID: srcConnID, 352 DestConnectionID: destConnID, 353 Version: 0x42, 354 }, make([]byte, protocol.MinUnknownVersionPacketSize)) 355 raddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337} 356 packet.remoteAddr = raddr 357 tracer.EXPECT().SentVersionNegotiationPacket(packet.remoteAddr, gomock.Any(), gomock.Any(), gomock.Any()).Do(func(_ net.Addr, src, dest protocol.ArbitraryLenConnectionID, _ []protocol.Version) { 358 Expect(src).To(Equal(protocol.ArbitraryLenConnectionID(destConnID.Bytes()))) 359 Expect(dest).To(Equal(protocol.ArbitraryLenConnectionID(srcConnID.Bytes()))) 360 }) 361 done := make(chan struct{}) 362 conn.EXPECT().WriteTo(gomock.Any(), raddr).DoAndReturn(func(b []byte, _ net.Addr) (int, error) { 363 defer close(done) 364 Expect(wire.IsVersionNegotiationPacket(b)).To(BeTrue()) 365 dest, src, versions, err := wire.ParseVersionNegotiationPacket(b) 366 Expect(err).ToNot(HaveOccurred()) 367 Expect(dest).To(Equal(protocol.ArbitraryLenConnectionID(srcConnID.Bytes()))) 368 Expect(src).To(Equal(protocol.ArbitraryLenConnectionID(destConnID.Bytes()))) 369 Expect(versions).ToNot(ContainElement(protocol.Version(0x42))) 370 return len(b), nil 371 }) 372 serv.handlePacket(packet) 373 Eventually(done).Should(BeClosed()) 374 }) 375 376 It("doesn't send a Version Negotiation packets if sending them is disabled", func() { 377 serv.disableVersionNegotiation = true 378 srcConnID 
:= protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5}) 379 destConnID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6}) 380 packet := getPacket(&wire.Header{ 381 Type: protocol.PacketTypeHandshake, 382 SrcConnectionID: srcConnID, 383 DestConnectionID: destConnID, 384 Version: 0x42, 385 }, make([]byte, protocol.MinUnknownVersionPacketSize)) 386 raddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337} 387 packet.remoteAddr = raddr 388 done := make(chan struct{}) 389 serv.handlePacket(packet) 390 Consistently(done, 50*time.Millisecond).ShouldNot(BeClosed()) 391 }) 392 393 It("ignores Version Negotiation packets", func() { 394 data := wire.ComposeVersionNegotiation( 395 protocol.ArbitraryLenConnectionID{1, 2, 3, 4}, 396 protocol.ArbitraryLenConnectionID{4, 3, 2, 1}, 397 []protocol.Version{1, 2, 3}, 398 ) 399 raddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337} 400 done := make(chan struct{}) 401 tracer.EXPECT().DroppedPacket(raddr, logging.PacketTypeVersionNegotiation, protocol.ByteCount(len(data)), logging.PacketDropUnexpectedPacket).Do(func(net.Addr, logging.PacketType, protocol.ByteCount, logging.PacketDropReason) { 402 close(done) 403 }) 404 serv.handlePacket(receivedPacket{ 405 remoteAddr: raddr, 406 data: data, 407 buffer: getPacketBuffer(), 408 }) 409 Eventually(done).Should(BeClosed()) 410 // make sure no other packet is sent 411 time.Sleep(scaleDuration(20 * time.Millisecond)) 412 }) 413 414 It("doesn't send a Version Negotiation Packet for unsupported versions, if the packet is too small", func() { 415 srcConnID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5}) 416 destConnID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6}) 417 p := getPacket(&wire.Header{ 418 Type: protocol.PacketTypeHandshake, 419 SrcConnectionID: srcConnID, 420 DestConnectionID: destConnID, 421 Version: 0x42, 422 }, make([]byte, protocol.MinUnknownVersionPacketSize-50)) 423 Expect(p.Size()).To(BeNumerically("<", protocol.MinUnknownVersionPacketSize)) 424 raddr 
:= &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337} 425 p.remoteAddr = raddr 426 done := make(chan struct{}) 427 tracer.EXPECT().DroppedPacket(raddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropUnexpectedPacket).Do(func(net.Addr, logging.PacketType, protocol.ByteCount, logging.PacketDropReason) { 428 close(done) 429 }) 430 serv.handlePacket(p) 431 Eventually(done).Should(BeClosed()) 432 // make sure no other packet is sent 433 time.Sleep(scaleDuration(20 * time.Millisecond)) 434 }) 435 436 It("replies with a Retry packet, if a token is required", func() { 437 connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}) 438 raddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337} 439 var called bool 440 serv.verifySourceAddress = func(addr net.Addr) bool { 441 Expect(addr).To(Equal(raddr)) 442 called = true 443 return true 444 } 445 hdr := &wire.Header{ 446 Type: protocol.PacketTypeInitial, 447 SrcConnectionID: protocol.ParseConnectionID([]byte{5, 4, 3, 2, 1}), 448 DestConnectionID: connID, 449 Version: protocol.Version1, 450 } 451 packet := getPacket(hdr, make([]byte, protocol.MinInitialPacketSize)) 452 packet.remoteAddr = raddr 453 tracer.EXPECT().SentPacket(packet.remoteAddr, gomock.Any(), gomock.Any(), nil).Do(func(_ net.Addr, replyHdr *logging.Header, _ logging.ByteCount, _ []logging.Frame) { 454 Expect(replyHdr.Type).To(Equal(protocol.PacketTypeRetry)) 455 Expect(replyHdr.SrcConnectionID).ToNot(Equal(hdr.DestConnectionID)) 456 Expect(replyHdr.DestConnectionID).To(Equal(hdr.SrcConnectionID)) 457 Expect(replyHdr.Token).ToNot(BeEmpty()) 458 }) 459 done := make(chan struct{}) 460 conn.EXPECT().WriteTo(gomock.Any(), raddr).DoAndReturn(func(b []byte, _ net.Addr) (int, error) { 461 defer close(done) 462 replyHdr := parseHeader(b) 463 Expect(replyHdr.Type).To(Equal(protocol.PacketTypeRetry)) 464 Expect(replyHdr.SrcConnectionID).ToNot(Equal(hdr.DestConnectionID)) 465 
Expect(replyHdr.DestConnectionID).To(Equal(hdr.SrcConnectionID)) 466 Expect(replyHdr.Token).ToNot(BeEmpty()) 467 Expect(b[len(b)-16:]).To(Equal(handshake.GetRetryIntegrityTag(b[:len(b)-16], hdr.DestConnectionID, hdr.Version)[:])) 468 return len(b), nil 469 }) 470 phm.EXPECT().Get(connID) 471 serv.handlePacket(packet) 472 Eventually(done).Should(BeClosed()) 473 Expect(called).To(BeTrue()) 474 }) 475 476 It("creates a connection, if no token is required", func() { 477 connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}) 478 hdr := &wire.Header{ 479 Type: protocol.PacketTypeInitial, 480 SrcConnectionID: protocol.ParseConnectionID([]byte{5, 4, 3, 2, 1}), 481 DestConnectionID: connID, 482 Version: protocol.Version1, 483 } 484 p := getPacket(hdr, make([]byte, protocol.MinInitialPacketSize)) 485 run := make(chan struct{}) 486 var token protocol.StatelessResetToken 487 rand.Read(token[:]) 488 489 var newConnID protocol.ConnectionID 490 conn := NewMockQUICConn(mockCtrl) 491 serv.newConn = func( 492 _ sendConn, 493 _ connRunner, 494 origDestConnID protocol.ConnectionID, 495 retrySrcConnID *protocol.ConnectionID, 496 clientDestConnID protocol.ConnectionID, 497 destConnID protocol.ConnectionID, 498 srcConnID protocol.ConnectionID, 499 _ ConnectionIDGenerator, 500 tokenP protocol.StatelessResetToken, 501 _ *Config, 502 _ *tls.Config, 503 _ *handshake.TokenGenerator, 504 _ bool, 505 _ *logging.ConnectionTracer, 506 _ uint64, 507 _ utils.Logger, 508 _ protocol.Version, 509 ) quicConn { 510 Expect(origDestConnID).To(Equal(hdr.DestConnectionID)) 511 Expect(retrySrcConnID).To(BeNil()) 512 Expect(clientDestConnID).To(Equal(hdr.DestConnectionID)) 513 Expect(destConnID).To(Equal(hdr.SrcConnectionID)) 514 // make sure we're using a server-generated connection ID 515 Expect(srcConnID).ToNot(Equal(hdr.DestConnectionID)) 516 Expect(srcConnID).ToNot(Equal(hdr.SrcConnectionID)) 517 newConnID = srcConnID 518 Expect(tokenP).To(Equal(token)) 519 
conn.EXPECT().handlePacket(p) 520 conn.EXPECT().run().Do(func() error { close(run); return nil }) 521 conn.EXPECT().Context().Return(context.Background()) 522 conn.EXPECT().HandshakeComplete().Return(make(chan struct{})) 523 return conn 524 } 525 gomock.InOrder( 526 phm.EXPECT().Get(connID), 527 phm.EXPECT().GetStatelessResetToken(gomock.Any()).Return(token), 528 phm.EXPECT().AddWithConnID(connID, gomock.Any(), gomock.Any()).DoAndReturn(func(_, c protocol.ConnectionID, h packetHandler) bool { 529 Expect(c).To(Equal(newConnID)) 530 return true 531 }), 532 ) 533 534 done := make(chan struct{}) 535 go func() { 536 defer GinkgoRecover() 537 serv.handlePacket(p) 538 // the Handshake packet is written by the connection 539 // make sure there are no Write calls on the packet conn 540 time.Sleep(50 * time.Millisecond) 541 close(done) 542 }() 543 // make sure we're using a server-generated connection ID 544 Eventually(run).Should(BeClosed()) 545 Eventually(done).Should(BeClosed()) 546 // shutdown 547 conn.EXPECT().closeWithTransportError(gomock.Any()).MaxTimes(1) 548 }) 549 550 It("drops packets if the receive queue is full", func() { 551 serv.verifySourceAddress = func(net.Addr) bool { return false } 552 553 phm.EXPECT().Get(gomock.Any()).AnyTimes() 554 phm.EXPECT().GetStatelessResetToken(gomock.Any()).AnyTimes() 555 phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).Return(true).AnyTimes() 556 557 acceptConn := make(chan struct{}) 558 var counter atomic.Uint32 559 serv.newConn = func( 560 _ sendConn, 561 runner connRunner, 562 _ protocol.ConnectionID, 563 _ *protocol.ConnectionID, 564 _ protocol.ConnectionID, 565 _ protocol.ConnectionID, 566 _ protocol.ConnectionID, 567 _ ConnectionIDGenerator, 568 _ protocol.StatelessResetToken, 569 _ *Config, 570 _ *tls.Config, 571 _ *handshake.TokenGenerator, 572 _ bool, 573 _ *logging.ConnectionTracer, 574 _ uint64, 575 _ utils.Logger, 576 _ protocol.Version, 577 ) quicConn { 578 <-acceptConn 579 counter.Add(1) 580 
conn := NewMockQUICConn(mockCtrl) 581 conn.EXPECT().handlePacket(gomock.Any()).MaxTimes(1) 582 conn.EXPECT().run().MaxTimes(1) 583 conn.EXPECT().Context().Return(context.Background()).MaxTimes(1) 584 conn.EXPECT().HandshakeComplete().Return(make(chan struct{})).MaxTimes(1) 585 // shutdown 586 conn.EXPECT().closeWithTransportError(gomock.Any()).MaxTimes(1) 587 return conn 588 } 589 590 p := getInitial(protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8})) 591 serv.handlePacket(p) 592 tracer.EXPECT().DroppedPacket(p.remoteAddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropDOSPrevention).MinTimes(1) 593 var wg sync.WaitGroup 594 for i := 0; i < 3*protocol.MaxServerUnprocessedPackets; i++ { 595 wg.Add(1) 596 go func() { 597 defer GinkgoRecover() 598 defer wg.Done() 599 serv.handlePacket(getInitial(protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}))) 600 }() 601 } 602 wg.Wait() 603 604 close(acceptConn) 605 Eventually( 606 func() uint32 { return counter.Load() }, 607 scaleDuration(100*time.Millisecond), 608 ).Should(BeEquivalentTo(protocol.MaxServerUnprocessedPackets + 1)) 609 Consistently(func() uint32 { return counter.Load() }).Should(BeEquivalentTo(protocol.MaxServerUnprocessedPackets + 1)) 610 }) 611 612 It("only creates a single connection for a duplicate Initial", func() { 613 done := make(chan struct{}) 614 serv.newConn = func( 615 _ sendConn, 616 runner connRunner, 617 _ protocol.ConnectionID, 618 _ *protocol.ConnectionID, 619 _ protocol.ConnectionID, 620 _ protocol.ConnectionID, 621 _ protocol.ConnectionID, 622 _ ConnectionIDGenerator, 623 _ protocol.StatelessResetToken, 624 _ *Config, 625 _ *tls.Config, 626 _ *handshake.TokenGenerator, 627 _ bool, 628 _ *logging.ConnectionTracer, 629 _ uint64, 630 _ utils.Logger, 631 _ protocol.Version, 632 ) quicConn { 633 conn := NewMockQUICConn(mockCtrl) 634 conn.EXPECT().handlePacket(gomock.Any()) 635 conn.EXPECT().closeWithTransportError(qerr.ConnectionRefused).Do(func(qerr.TransportErrorCode) 
{ 636 close(done) 637 }) 638 return conn 639 } 640 641 connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9}) 642 p := getInitial(connID) 643 phm.EXPECT().Get(connID) 644 phm.EXPECT().GetStatelessResetToken(gomock.Any()) 645 phm.EXPECT().AddWithConnID(connID, gomock.Any(), gomock.Any()).Return(false) // connection ID collision 646 Expect(serv.handlePacketImpl(p)).To(BeTrue()) 647 Eventually(done).Should(BeClosed()) 648 }) 649 650 It("limits the number of unvalidated handshakes", func() { 651 const limit = 3 652 limiter := rate.NewLimiter(0, limit) 653 serv.verifySourceAddress = func(net.Addr) bool { return !limiter.Allow() } 654 655 phm.EXPECT().Get(gomock.Any()).AnyTimes() 656 phm.EXPECT().GetStatelessResetToken(gomock.Any()).AnyTimes() 657 phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).Return(true).AnyTimes() 658 659 connChan := make(chan *MockQUICConn, 1) 660 var wg sync.WaitGroup 661 wg.Add(limit) 662 done := make(chan struct{}) 663 serv.newConn = func( 664 _ sendConn, 665 runner connRunner, 666 _ protocol.ConnectionID, 667 _ *protocol.ConnectionID, 668 _ protocol.ConnectionID, 669 _ protocol.ConnectionID, 670 _ protocol.ConnectionID, 671 _ ConnectionIDGenerator, 672 _ protocol.StatelessResetToken, 673 _ *Config, 674 _ *tls.Config, 675 _ *handshake.TokenGenerator, 676 _ bool, 677 _ *logging.ConnectionTracer, 678 _ uint64, 679 _ utils.Logger, 680 _ protocol.Version, 681 ) quicConn { 682 conn := <-connChan 683 conn.EXPECT().handlePacket(gomock.Any()) 684 conn.EXPECT().run() 685 conn.EXPECT().Context().Return(context.Background()) 686 conn.EXPECT().HandshakeComplete().DoAndReturn(func() <-chan struct{} { wg.Done(); return done }) 687 return conn 688 } 689 690 // Initiate the maximum number of allowed connection attempts. 691 for i := 0; i < limit; i++ { 692 conn := NewMockQUICConn(mockCtrl) 693 connChan <- conn 694 serv.handlePacket(getInitialWithRandomDestConnID()) 695 } 696 697 // Now initiate another connection attempt. 
698 p := getInitialWithRandomDestConnID() 699 tracer.EXPECT().SentPacket(p.remoteAddr, gomock.Any(), gomock.Any(), gomock.Any()).Do(func(_ net.Addr, replyHdr *logging.Header, _ logging.ByteCount, frames []logging.Frame) { 700 defer GinkgoRecover() 701 Expect(replyHdr.Type).To(Equal(protocol.PacketTypeRetry)) 702 }) 703 conn.EXPECT().WriteTo(gomock.Any(), gomock.Any()).DoAndReturn(func(b []byte, _ net.Addr) (int, error) { 704 defer GinkgoRecover() 705 defer close(done) 706 hdr, _, _, err := wire.ParsePacket(b) 707 Expect(err).ToNot(HaveOccurred()) 708 Expect(hdr.Type).To(Equal(protocol.PacketTypeRetry)) 709 return len(b), nil 710 }) 711 serv.handlePacket(p) 712 Eventually(done).Should(BeClosed()) 713 714 for i := 0; i < limit; i++ { 715 _, err := serv.Accept(context.Background()) 716 Expect(err).ToNot(HaveOccurred()) 717 } 718 wg.Wait() 719 }) 720 }) 721 722 Context("token validation", func() { 723 It("decodes the token from the token field", func() { 724 serv.newConn = func( 725 _ sendConn, 726 _ connRunner, 727 _ protocol.ConnectionID, 728 _ *protocol.ConnectionID, 729 _ protocol.ConnectionID, 730 _ protocol.ConnectionID, 731 _ protocol.ConnectionID, 732 _ ConnectionIDGenerator, 733 _ protocol.StatelessResetToken, 734 _ *Config, 735 _ *tls.Config, 736 _ *handshake.TokenGenerator, 737 _ bool, 738 _ *logging.ConnectionTracer, 739 _ uint64, 740 _ utils.Logger, 741 _ protocol.Version, 742 ) quicConn { 743 c := NewMockQUICConn(mockCtrl) 744 c.EXPECT().handlePacket(gomock.Any()) 745 c.EXPECT().run() 746 c.EXPECT().HandshakeComplete() 747 ctx, cancel := context.WithCancel(context.Background()) 748 cancel() 749 c.EXPECT().Context().Return(ctx) 750 return c 751 } 752 raddr := &net.UDPAddr{IP: net.IPv4(192, 168, 13, 37), Port: 1337} 753 token, err := serv.tokenGenerator.NewRetryToken(raddr, protocol.ConnectionID{}, protocol.ConnectionID{}) 754 Expect(err).ToNot(HaveOccurred()) 755 packet := getPacket(&wire.Header{ 756 Type: protocol.PacketTypeInitial, 757 Token: token, 758 
Version: serv.config.Versions[0], 759 }, make([]byte, protocol.MinInitialPacketSize)) 760 packet.remoteAddr = raddr 761 conn.EXPECT().WriteTo(gomock.Any(), gomock.Any()).MaxTimes(1) 762 tracer.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).MaxTimes(1) 763 764 done := make(chan struct{}) 765 phm.EXPECT().Get(gomock.Any()) 766 phm.EXPECT().GetStatelessResetToken(gomock.Any()) 767 phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, _ packetHandler) bool { 768 close(done) 769 return true 770 }) 771 phm.EXPECT().Remove(gomock.Any()).AnyTimes() 772 serv.handlePacket(packet) 773 Eventually(done).Should(BeClosed()) 774 }) 775 776 It("sends an INVALID_TOKEN error, if an invalid retry token is received", func() { 777 serv.verifySourceAddress = func(net.Addr) bool { return true } 778 token, err := serv.tokenGenerator.NewRetryToken(&net.UDPAddr{}, protocol.ConnectionID{}, protocol.ConnectionID{}) 779 Expect(err).ToNot(HaveOccurred()) 780 hdr := &wire.Header{ 781 Type: protocol.PacketTypeInitial, 782 SrcConnectionID: protocol.ParseConnectionID([]byte{5, 4, 3, 2, 1}), 783 DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}), 784 Token: token, 785 Version: protocol.Version1, 786 } 787 packet := getPacket(hdr, make([]byte, protocol.MinInitialPacketSize)) 788 packet.data = append(packet.data, []byte("coalesced packet")...) 
// add some garbage to simulate a coalesced packet 789 raddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337} 790 packet.remoteAddr = raddr 791 tracer.EXPECT().SentPacket(packet.remoteAddr, gomock.Any(), gomock.Any(), gomock.Any()).Do(func(_ net.Addr, replyHdr *logging.Header, _ logging.ByteCount, frames []logging.Frame) { 792 Expect(replyHdr.Type).To(Equal(protocol.PacketTypeInitial)) 793 Expect(replyHdr.SrcConnectionID).To(Equal(hdr.DestConnectionID)) 794 Expect(replyHdr.DestConnectionID).To(Equal(hdr.SrcConnectionID)) 795 Expect(frames).To(HaveLen(1)) 796 Expect(frames[0]).To(BeAssignableToTypeOf(&wire.ConnectionCloseFrame{})) 797 ccf := frames[0].(*logging.ConnectionCloseFrame) 798 Expect(ccf.IsApplicationError).To(BeFalse()) 799 Expect(ccf.ErrorCode).To(BeEquivalentTo(qerr.InvalidToken)) 800 }) 801 done := make(chan struct{}) 802 conn.EXPECT().WriteTo(gomock.Any(), raddr).DoAndReturn(func(b []byte, _ net.Addr) (int, error) { 803 defer close(done) 804 checkConnectionCloseError(b, hdr, qerr.InvalidToken) 805 return len(b), nil 806 }) 807 phm.EXPECT().Get(gomock.Any()) 808 serv.handlePacket(packet) 809 Eventually(done).Should(BeClosed()) 810 }) 811 812 It("sends an INVALID_TOKEN error, if an expired retry token is received", func() { 813 serv.verifySourceAddress = func(net.Addr) bool { return true } 814 serv.config.HandshakeIdleTimeout = time.Millisecond / 2 // the maximum retry token age is equivalent to the handshake timeout 815 Expect(serv.config.maxRetryTokenAge()).To(Equal(time.Millisecond)) 816 raddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337} 817 token, err := serv.tokenGenerator.NewRetryToken(raddr, protocol.ConnectionID{}, protocol.ConnectionID{}) 818 Expect(err).ToNot(HaveOccurred()) 819 time.Sleep(2 * time.Millisecond) // make sure the token is expired 820 hdr := &wire.Header{ 821 Type: protocol.PacketTypeInitial, 822 SrcConnectionID: protocol.ParseConnectionID([]byte{5, 4, 3, 2, 1}), 823 DestConnectionID: 
protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}), 824 Token: token, 825 Version: protocol.Version1, 826 } 827 packet := getPacket(hdr, make([]byte, protocol.MinInitialPacketSize)) 828 packet.remoteAddr = raddr 829 tracer.EXPECT().SentPacket(packet.remoteAddr, gomock.Any(), gomock.Any(), gomock.Any()).Do(func(_ net.Addr, replyHdr *logging.Header, _ logging.ByteCount, frames []logging.Frame) { 830 Expect(replyHdr.Type).To(Equal(protocol.PacketTypeInitial)) 831 Expect(replyHdr.SrcConnectionID).To(Equal(hdr.DestConnectionID)) 832 Expect(replyHdr.DestConnectionID).To(Equal(hdr.SrcConnectionID)) 833 Expect(frames).To(HaveLen(1)) 834 Expect(frames[0]).To(BeAssignableToTypeOf(&wire.ConnectionCloseFrame{})) 835 ccf := frames[0].(*logging.ConnectionCloseFrame) 836 Expect(ccf.IsApplicationError).To(BeFalse()) 837 Expect(ccf.ErrorCode).To(BeEquivalentTo(qerr.InvalidToken)) 838 }) 839 done := make(chan struct{}) 840 conn.EXPECT().WriteTo(gomock.Any(), raddr).DoAndReturn(func(b []byte, _ net.Addr) (int, error) { 841 defer close(done) 842 checkConnectionCloseError(b, hdr, qerr.InvalidToken) 843 return len(b), nil 844 }) 845 phm.EXPECT().Get(gomock.Any()) 846 serv.handlePacket(packet) 847 Eventually(done).Should(BeClosed()) 848 }) 849 850 It("doesn't send an INVALID_TOKEN error, if an invalid non-retry token is received", func() { 851 serv.verifySourceAddress = func(net.Addr) bool { return true } 852 token, err := serv.tokenGenerator.NewToken(&net.UDPAddr{IP: net.IPv4(192, 168, 0, 1), Port: 1337}) 853 Expect(err).ToNot(HaveOccurred()) 854 hdr := &wire.Header{ 855 Type: protocol.PacketTypeInitial, 856 SrcConnectionID: protocol.ParseConnectionID([]byte{5, 4, 3, 2, 1}), 857 DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}), 858 Token: token, 859 Version: protocol.Version1, 860 } 861 packet := getPacket(hdr, make([]byte, protocol.MinInitialPacketSize)) 862 packet.data[len(packet.data)-10] ^= 0xff // corrupt the packet 863 raddr := 
&net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337} 864 packet.remoteAddr = raddr 865 tracer.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).MaxTimes(1) 866 done := make(chan struct{}) 867 conn.EXPECT().WriteTo(gomock.Any(), raddr).DoAndReturn(func(b []byte, _ net.Addr) (int, error) { 868 defer close(done) 869 replyHdr := parseHeader(b) 870 Expect(replyHdr.Type).To(Equal(protocol.PacketTypeRetry)) 871 return len(b), nil 872 }) 873 phm.EXPECT().Get(gomock.Any()) 874 serv.handlePacket(packet) 875 // make sure there are no Write calls on the packet conn 876 Eventually(done).Should(BeClosed()) 877 }) 878 879 It("sends an INVALID_TOKEN error, if an expired non-retry token is received", func() { 880 serv.verifySourceAddress = func(net.Addr) bool { return true } 881 serv.maxTokenAge = time.Millisecond 882 raddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337} 883 token, err := serv.tokenGenerator.NewToken(raddr) 884 Expect(err).ToNot(HaveOccurred()) 885 time.Sleep(2 * time.Millisecond) // make sure the token is expired 886 hdr := &wire.Header{ 887 Type: protocol.PacketTypeInitial, 888 SrcConnectionID: protocol.ParseConnectionID([]byte{5, 4, 3, 2, 1}), 889 DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}), 890 Token: token, 891 Version: protocol.Version1, 892 } 893 packet := getPacket(hdr, make([]byte, protocol.MinInitialPacketSize)) 894 packet.remoteAddr = raddr 895 tracer.EXPECT().SentPacket(packet.remoteAddr, gomock.Any(), gomock.Any(), gomock.Any()).Do(func(_ net.Addr, replyHdr *logging.Header, _ logging.ByteCount, frames []logging.Frame) { 896 Expect(replyHdr.Type).To(Equal(protocol.PacketTypeRetry)) 897 }) 898 done := make(chan struct{}) 899 conn.EXPECT().WriteTo(gomock.Any(), raddr).DoAndReturn(func(b []byte, _ net.Addr) (int, error) { 900 defer close(done) 901 return len(b), nil 902 }) 903 phm.EXPECT().Get(gomock.Any()) 904 serv.handlePacket(packet) 905 Eventually(done).Should(BeClosed()) 906 }) 
907 908 It("doesn't send an INVALID_TOKEN error, if the packet is corrupted", func() { 909 token, err := serv.tokenGenerator.NewRetryToken(&net.UDPAddr{}, protocol.ConnectionID{}, protocol.ConnectionID{}) 910 Expect(err).ToNot(HaveOccurred()) 911 hdr := &wire.Header{ 912 Type: protocol.PacketTypeInitial, 913 SrcConnectionID: protocol.ParseConnectionID([]byte{5, 4, 3, 2, 1}), 914 DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}), 915 Token: token, 916 Version: protocol.Version1, 917 } 918 packet := getPacket(hdr, make([]byte, protocol.MinInitialPacketSize)) 919 packet.data[len(packet.data)-10] ^= 0xff // corrupt the packet 920 packet.remoteAddr = &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337} 921 done := make(chan struct{}) 922 tracer.EXPECT().DroppedPacket(packet.remoteAddr, logging.PacketTypeInitial, packet.Size(), logging.PacketDropPayloadDecryptError).Do(func(net.Addr, logging.PacketType, protocol.ByteCount, logging.PacketDropReason) { close(done) }) 923 phm.EXPECT().Get(gomock.Any()) 924 serv.handlePacket(packet) 925 // make sure there are no Write calls on the packet conn 926 time.Sleep(50 * time.Millisecond) 927 Eventually(done).Should(BeClosed()) 928 }) 929 }) 930 931 Context("accepting connections", func() { 932 It("returns Accept when closed", func() { 933 done := make(chan struct{}) 934 go func() { 935 defer GinkgoRecover() 936 _, err := serv.Accept(context.Background()) 937 Expect(err).To(MatchError(ErrServerClosed)) 938 close(done) 939 }() 940 941 serv.Close() 942 Eventually(done).Should(BeClosed()) 943 }) 944 945 It("returns immediately, if an error occurred before", func() { 946 serv.Close() 947 for i := 0; i < 3; i++ { 948 _, err := serv.Accept(context.Background()) 949 Expect(err).To(MatchError(ErrServerClosed)) 950 } 951 }) 952 953 PIt("closes connection that are still handshaking after Close", func() { 954 serv.Close() 955 956 destroyed := make(chan struct{}) 957 serv.newConn = func( 958 _ sendConn, 959 _ 
connRunner, 960 _ protocol.ConnectionID, 961 _ *protocol.ConnectionID, 962 _ protocol.ConnectionID, 963 _ protocol.ConnectionID, 964 _ protocol.ConnectionID, 965 _ ConnectionIDGenerator, 966 _ protocol.StatelessResetToken, 967 conf *Config, 968 _ *tls.Config, 969 _ *handshake.TokenGenerator, 970 _ bool, 971 _ *logging.ConnectionTracer, 972 _ uint64, 973 _ utils.Logger, 974 _ protocol.Version, 975 ) quicConn { 976 conn := NewMockQUICConn(mockCtrl) 977 conn.EXPECT().handlePacket(gomock.Any()) 978 conn.EXPECT().closeWithTransportError(ConnectionRefused).Do(func(TransportErrorCode) { close(destroyed) }) 979 conn.EXPECT().HandshakeComplete().Return(make(chan struct{})) 980 conn.EXPECT().run().MaxTimes(1) 981 conn.EXPECT().Context().Return(context.Background()) 982 return conn 983 } 984 phm.EXPECT().Get(gomock.Any()) 985 phm.EXPECT().GetStatelessResetToken(gomock.Any()) 986 phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).Return(true) 987 serv.handleInitialImpl( 988 receivedPacket{buffer: getPacketBuffer()}, 989 &wire.Header{DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8})}, 990 ) 991 Eventually(destroyed).Should(BeClosed()) 992 }) 993 994 It("returns when the context is canceled", func() { 995 ctx, cancel := context.WithCancel(context.Background()) 996 done := make(chan struct{}) 997 go func() { 998 defer GinkgoRecover() 999 _, err := serv.Accept(ctx) 1000 Expect(err).To(MatchError("context canceled")) 1001 close(done) 1002 }() 1003 1004 Consistently(done).ShouldNot(BeClosed()) 1005 cancel() 1006 Eventually(done).Should(BeClosed()) 1007 }) 1008 1009 It("uses the config returned by GetConfigClient", func() { 1010 conn := NewMockQUICConn(mockCtrl) 1011 1012 conf := &Config{MaxIncomingStreams: 1234} 1013 serv.config = populateConfig(&Config{GetConfigForClient: func(*ClientHelloInfo) (*Config, error) { return conf, nil }}) 1014 done := make(chan struct{}) 1015 go func() { 1016 defer GinkgoRecover() 1017 s, err := 
serv.Accept(context.Background()) 1018 Expect(err).ToNot(HaveOccurred()) 1019 Expect(s).To(Equal(conn)) 1020 close(done) 1021 }() 1022 1023 handshakeChan := make(chan struct{}) 1024 serv.newConn = func( 1025 _ sendConn, 1026 _ connRunner, 1027 _ protocol.ConnectionID, 1028 _ *protocol.ConnectionID, 1029 _ protocol.ConnectionID, 1030 _ protocol.ConnectionID, 1031 _ protocol.ConnectionID, 1032 _ ConnectionIDGenerator, 1033 _ protocol.StatelessResetToken, 1034 conf *Config, 1035 _ *tls.Config, 1036 _ *handshake.TokenGenerator, 1037 _ bool, 1038 _ *logging.ConnectionTracer, 1039 _ uint64, 1040 _ utils.Logger, 1041 _ protocol.Version, 1042 ) quicConn { 1043 Expect(conf.MaxIncomingStreams).To(BeEquivalentTo(1234)) 1044 conn.EXPECT().handlePacket(gomock.Any()) 1045 conn.EXPECT().HandshakeComplete().Return(handshakeChan) 1046 conn.EXPECT().run() 1047 conn.EXPECT().Context().Return(context.Background()) 1048 return conn 1049 } 1050 phm.EXPECT().Get(gomock.Any()) 1051 phm.EXPECT().GetStatelessResetToken(gomock.Any()) 1052 phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).Return(true) 1053 serv.handleInitialImpl( 1054 receivedPacket{buffer: getPacketBuffer()}, 1055 &wire.Header{DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8})}, 1056 ) 1057 Consistently(done).ShouldNot(BeClosed()) 1058 close(handshakeChan) // complete the handshake 1059 Eventually(done).Should(BeClosed()) 1060 }) 1061 1062 It("rejects a connection attempt when GetConfigClient returns an error", func() { 1063 serv.config = populateConfig(&Config{GetConfigForClient: func(*ClientHelloInfo) (*Config, error) { return nil, errors.New("rejected") }}) 1064 1065 phm.EXPECT().Get(gomock.Any()) 1066 done := make(chan struct{}) 1067 tracer.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) 1068 conn.EXPECT().WriteTo(gomock.Any(), gomock.Any()).DoAndReturn(func(b []byte, _ net.Addr) (int, error) { 1069 defer close(done) 1070 rejectHdr := parseHeader(b) 
1071 Expect(rejectHdr.Type).To(Equal(protocol.PacketTypeInitial)) 1072 return len(b), nil 1073 }) 1074 serv.handleInitialImpl( 1075 receivedPacket{buffer: getPacketBuffer()}, 1076 &wire.Header{DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}), Version: protocol.Version1}, 1077 ) 1078 Eventually(done).Should(BeClosed()) 1079 }) 1080 1081 It("accepts new connections when the handshake completes", func() { 1082 conn := NewMockQUICConn(mockCtrl) 1083 1084 done := make(chan struct{}) 1085 go func() { 1086 defer GinkgoRecover() 1087 s, err := serv.Accept(context.Background()) 1088 Expect(err).ToNot(HaveOccurred()) 1089 Expect(s).To(Equal(conn)) 1090 close(done) 1091 }() 1092 1093 handshakeChan := make(chan struct{}) 1094 serv.newConn = func( 1095 _ sendConn, 1096 runner connRunner, 1097 _ protocol.ConnectionID, 1098 _ *protocol.ConnectionID, 1099 _ protocol.ConnectionID, 1100 _ protocol.ConnectionID, 1101 _ protocol.ConnectionID, 1102 _ ConnectionIDGenerator, 1103 _ protocol.StatelessResetToken, 1104 _ *Config, 1105 _ *tls.Config, 1106 _ *handshake.TokenGenerator, 1107 _ bool, 1108 _ *logging.ConnectionTracer, 1109 _ uint64, 1110 _ utils.Logger, 1111 _ protocol.Version, 1112 ) quicConn { 1113 conn.EXPECT().handlePacket(gomock.Any()) 1114 conn.EXPECT().HandshakeComplete().Return(handshakeChan) 1115 conn.EXPECT().run() 1116 conn.EXPECT().Context().Return(context.Background()) 1117 return conn 1118 } 1119 phm.EXPECT().Get(gomock.Any()) 1120 phm.EXPECT().GetStatelessResetToken(gomock.Any()) 1121 phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).Return(true) 1122 serv.handleInitialImpl( 1123 receivedPacket{buffer: getPacketBuffer()}, 1124 &wire.Header{DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8})}, 1125 ) 1126 Consistently(done).ShouldNot(BeClosed()) 1127 close(handshakeChan) // complete the handshake 1128 Eventually(done).Should(BeClosed()) 1129 }) 1130 }) 1131 }) 1132 1133 Context("server accepting 
// Accept on the early listener must not return a connection before the
// channel returned by earlyConnReady is closed.
It("accepts new connections when they become ready", func() {
	conn := NewMockQUICConn(mockCtrl)

	// Accept runs in the background; done is closed once it has returned
	// the mocked connection.
	done := make(chan struct{})
	go func() {
		defer GinkgoRecover()
		s, err := serv.Accept(context.Background())
		Expect(err).ToNot(HaveOccurred())
		Expect(s).To(Equal(conn))
		close(done)
	}()

	// ready is handed out via earlyConnReady and closed by the test below.
	ready := make(chan struct{})
	// Replace the connection constructor so that the server gets the mock.
	serv.baseServer.newConn = func(
		_ sendConn,
		runner connRunner,
		_ protocol.ConnectionID,
		_ *protocol.ConnectionID,
		_ protocol.ConnectionID,
		_ protocol.ConnectionID,
		_ protocol.ConnectionID,
		_ ConnectionIDGenerator,
		_ protocol.StatelessResetToken,
		_ *Config,
		_ *tls.Config,
		_ *handshake.TokenGenerator,
		_ bool,
		_ *logging.ConnectionTracer,
		_ uint64,
		_ utils.Logger,
		_ protocol.Version,
	) quicConn {
		conn.EXPECT().handlePacket(gomock.Any())
		conn.EXPECT().run()
		conn.EXPECT().earlyConnReady().Return(ready)
		conn.EXPECT().Context().Return(context.Background())
		return conn
	}
	phm.EXPECT().Get(gomock.Any())
	phm.EXPECT().GetStatelessResetToken(gomock.Any())
	phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).Return(true)
	serv.baseServer.handleInitialImpl(
		receivedPacket{buffer: getPacketBuffer()},
		&wire.Header{DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8})},
	)
	// Not accepted while the connection is not ready yet.
	Consistently(done).ShouldNot(BeClosed())
	close(ready)
	Eventually(done).Should(BeClosed())
})
// Fills the accept queue with MaxAcceptQueueSize connections, then checks
// that one more connection attempt is closed with ConnectionRefused.
It("rejects new connection attempts if the accept queue is full", func() {
	// The test passes each mock to the newConn stub through this channel.
	connChan := make(chan *MockQUICConn, 1)
	var wg sync.WaitGroup // to make sure the test fully completes
	// One Done per newConn call: MaxAcceptQueueSize accepted + 1 rejected.
	wg.Add(protocol.MaxAcceptQueueSize + 1)
	serv.baseServer.newConn = func(
		_ sendConn,
		runner connRunner,
		_ protocol.ConnectionID,
		_ *protocol.ConnectionID,
		_ protocol.ConnectionID,
		_ protocol.ConnectionID,
		_ protocol.ConnectionID,
		_ ConnectionIDGenerator,
		_ protocol.StatelessResetToken,
		_ *Config,
		_ *tls.Config,
		_ *handshake.TokenGenerator,
		_ bool,
		_ *logging.ConnectionTracer,
		_ uint64,
		_ utils.Logger,
		_ protocol.Version,
	) quicConn {
		defer wg.Done()
		// Every connection reports itself as ready immediately.
		ready := make(chan struct{})
		close(ready)
		conn := <-connChan
		conn.EXPECT().handlePacket(gomock.Any())
		conn.EXPECT().run()
		conn.EXPECT().earlyConnReady().Return(ready)
		conn.EXPECT().Context().Return(context.Background())
		return conn
	}

	phm.EXPECT().Get(gomock.Any()).AnyTimes()
	phm.EXPECT().GetStatelessResetToken(gomock.Any()).Times(protocol.MaxAcceptQueueSize)
	phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).Return(true).Times(protocol.MaxAcceptQueueSize)
	for i := 0; i < protocol.MaxAcceptQueueSize; i++ {
		conn := NewMockQUICConn(mockCtrl)
		connChan <- conn
		serv.baseServer.handlePacket(getInitialWithRandomDestConnID())
	}

	// Wait until the queue is full before sending the extra attempt.
	Eventually(serv.baseServer.connQueue).Should(HaveLen(protocol.MaxAcceptQueueSize))

	phm.EXPECT().GetStatelessResetToken(gomock.Any())
	phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).Return(true)
	// The connection that does not fit into the queue must be refused.
	conn := NewMockQUICConn(mockCtrl)
	conn.EXPECT().closeWithTransportError(ConnectionRefused)
	connChan <- conn
	serv.baseServer.handlePacket(getInitialWithRandomDestConnID())
	wg.Wait()
})
in the mean time", func() { 1256 p := getInitial(protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10})) 1257 ctx, cancel := context.WithCancel(context.Background()) 1258 connCreated := make(chan struct{}) 1259 conn := NewMockQUICConn(mockCtrl) 1260 serv.baseServer.newConn = func( 1261 _ sendConn, 1262 runner connRunner, 1263 _ protocol.ConnectionID, 1264 _ *protocol.ConnectionID, 1265 _ protocol.ConnectionID, 1266 _ protocol.ConnectionID, 1267 _ protocol.ConnectionID, 1268 _ ConnectionIDGenerator, 1269 _ protocol.StatelessResetToken, 1270 _ *Config, 1271 _ *tls.Config, 1272 _ *handshake.TokenGenerator, 1273 _ bool, 1274 _ *logging.ConnectionTracer, 1275 _ uint64, 1276 _ utils.Logger, 1277 _ protocol.Version, 1278 ) quicConn { 1279 conn.EXPECT().handlePacket(p) 1280 conn.EXPECT().run() 1281 conn.EXPECT().earlyConnReady() 1282 conn.EXPECT().Context().Return(ctx) 1283 close(connCreated) 1284 return conn 1285 } 1286 1287 phm.EXPECT().Get(gomock.Any()) 1288 phm.EXPECT().GetStatelessResetToken(gomock.Any()) 1289 phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).Return(true) 1290 serv.baseServer.handlePacket(p) 1291 // make sure there are no Write calls on the packet conn 1292 time.Sleep(50 * time.Millisecond) 1293 Eventually(connCreated).Should(BeClosed()) 1294 cancel() 1295 time.Sleep(scaleDuration(200 * time.Millisecond)) 1296 1297 done := make(chan struct{}) 1298 go func() { 1299 defer GinkgoRecover() 1300 serv.Accept(context.Background()) 1301 close(done) 1302 }() 1303 Consistently(done).ShouldNot(BeClosed()) 1304 1305 // make the go routine return 1306 Expect(serv.Close()).To(Succeed()) 1307 Eventually(done).Should(BeClosed()) 1308 }) 1309 }) 1310 1311 Context("0-RTT", func() { 1312 var ( 1313 tr *Transport 1314 serv *baseServer 1315 phm *MockPacketHandlerManager 1316 tracer *mocklogging.MockTracer 1317 ) 1318 1319 BeforeEach(func() { 1320 var t *logging.Tracer 1321 t, tracer = mocklogging.NewMockTracer(mockCtrl) 1322 tr = 
&Transport{Conn: conn, Tracer: t} 1323 ln, err := tr.ListenEarly(tlsConf, nil) 1324 Expect(err).ToNot(HaveOccurred()) 1325 phm = NewMockPacketHandlerManager(mockCtrl) 1326 serv = ln.baseServer 1327 serv.connHandler = phm 1328 }) 1329 1330 AfterEach(func() { 1331 tracer.EXPECT().Close() 1332 Expect(tr.Close()).To(Succeed()) 1333 }) 1334 1335 It("passes packets to existing connections", func() { 1336 connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}) 1337 p := getPacket(&wire.Header{ 1338 Type: protocol.PacketType0RTT, 1339 DestConnectionID: connID, 1340 Version: serv.config.Versions[0], 1341 }, make([]byte, 100)) 1342 conn := NewMockPacketHandler(mockCtrl) 1343 phm.EXPECT().Get(connID).Return(conn, true) 1344 handled := make(chan struct{}) 1345 conn.EXPECT().handlePacket(p).Do(func(receivedPacket) { close(handled) }) 1346 serv.handlePacket(p) 1347 Eventually(handled).Should(BeClosed()) 1348 }) 1349 1350 It("queues 0-RTT packets, up to Max0RTTQueueSize", func() { 1351 connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}) 1352 1353 var zeroRTTPackets []receivedPacket 1354 1355 for i := 0; i < protocol.Max0RTTQueueLen; i++ { 1356 p := getPacket(&wire.Header{ 1357 Type: protocol.PacketType0RTT, 1358 DestConnectionID: connID, 1359 Version: serv.config.Versions[0], 1360 }, make([]byte, 100+i)) 1361 phm.EXPECT().Get(connID) 1362 serv.handlePacket(p) 1363 zeroRTTPackets = append(zeroRTTPackets, p) 1364 } 1365 1366 // send one more packet, this one should be dropped 1367 p := getPacket(&wire.Header{ 1368 Type: protocol.PacketType0RTT, 1369 DestConnectionID: connID, 1370 Version: serv.config.Versions[0], 1371 }, make([]byte, 200)) 1372 phm.EXPECT().Get(connID) 1373 tracer.EXPECT().DroppedPacket(p.remoteAddr, logging.PacketType0RTT, p.Size(), logging.PacketDropDOSPrevention) 1374 serv.handlePacket(p) 1375 1376 initial := getPacket(&wire.Header{ 1377 Type: protocol.PacketTypeInitial, 1378 DestConnectionID: connID, 1379 Version: 
serv.config.Versions[0], 1380 }, make([]byte, protocol.MinInitialPacketSize)) 1381 called := make(chan struct{}) 1382 serv.newConn = func( 1383 _ sendConn, 1384 _ connRunner, 1385 _ protocol.ConnectionID, 1386 _ *protocol.ConnectionID, 1387 _ protocol.ConnectionID, 1388 _ protocol.ConnectionID, 1389 _ protocol.ConnectionID, 1390 _ ConnectionIDGenerator, 1391 _ protocol.StatelessResetToken, 1392 _ *Config, 1393 _ *tls.Config, 1394 _ *handshake.TokenGenerator, 1395 _ bool, 1396 _ *logging.ConnectionTracer, 1397 _ uint64, 1398 _ utils.Logger, 1399 _ protocol.Version, 1400 ) quicConn { 1401 conn := NewMockQUICConn(mockCtrl) 1402 var calls []any 1403 calls = append(calls, conn.EXPECT().handlePacket(initial)) 1404 for _, p := range zeroRTTPackets { 1405 calls = append(calls, conn.EXPECT().handlePacket(p)) 1406 } 1407 gomock.InOrder(calls...) 1408 conn.EXPECT().run() 1409 conn.EXPECT().earlyConnReady() 1410 conn.EXPECT().Context().Return(context.Background()) 1411 close(called) 1412 // shutdown 1413 conn.EXPECT().closeWithTransportError(gomock.Any()) 1414 return conn 1415 } 1416 1417 phm.EXPECT().Get(connID) 1418 phm.EXPECT().GetStatelessResetToken(gomock.Any()) 1419 phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).Return(true) 1420 serv.handlePacket(initial) 1421 Eventually(called).Should(BeClosed()) 1422 }) 1423 1424 It("limits the number of queues", func() { 1425 for i := 0; i < protocol.Max0RTTQueues; i++ { 1426 b := make([]byte, 16) 1427 rand.Read(b) 1428 connID := protocol.ParseConnectionID(b) 1429 p := getPacket(&wire.Header{ 1430 Type: protocol.PacketType0RTT, 1431 DestConnectionID: connID, 1432 Version: serv.config.Versions[0], 1433 }, make([]byte, 100+i)) 1434 phm.EXPECT().Get(connID) 1435 serv.handlePacket(p) 1436 } 1437 1438 connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}) 1439 p := getPacket(&wire.Header{ 1440 Type: protocol.PacketType0RTT, 1441 DestConnectionID: connID, 1442 Version: serv.config.Versions[0], 1443 }, 
make([]byte, 200)) 1444 phm.EXPECT().Get(connID) 1445 dropped := make(chan struct{}) 1446 tracer.EXPECT().DroppedPacket(p.remoteAddr, logging.PacketType0RTT, p.Size(), logging.PacketDropDOSPrevention).Do(func(net.Addr, logging.PacketType, protocol.ByteCount, logging.PacketDropReason) { 1447 close(dropped) 1448 }) 1449 serv.handlePacket(p) 1450 Eventually(dropped).Should(BeClosed()) 1451 }) 1452 1453 It("drops queues after a while", func() { 1454 now := time.Now() 1455 1456 connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}) 1457 p := getPacket(&wire.Header{ 1458 Type: protocol.PacketType0RTT, 1459 DestConnectionID: connID, 1460 Version: serv.config.Versions[0], 1461 }, make([]byte, 200)) 1462 p.rcvTime = now 1463 1464 connID2 := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 9}) 1465 p2Time := now.Add(protocol.Max0RTTQueueingDuration / 2) 1466 p2 := getPacket(&wire.Header{ 1467 Type: protocol.PacketType0RTT, 1468 DestConnectionID: connID2, 1469 Version: serv.config.Versions[0], 1470 }, make([]byte, 300)) 1471 p2.rcvTime = p2Time // doesn't trigger the cleanup of the first packet 1472 1473 dropped1 := make(chan struct{}) 1474 dropped2 := make(chan struct{}) 1475 // need to register the call before handling the packet to avoid race condition 1476 gomock.InOrder( 1477 tracer.EXPECT().DroppedPacket(p.remoteAddr, logging.PacketType0RTT, p.Size(), logging.PacketDropDOSPrevention).Do(func(net.Addr, logging.PacketType, protocol.ByteCount, logging.PacketDropReason) { 1478 close(dropped1) 1479 }), 1480 tracer.EXPECT().DroppedPacket(p2.remoteAddr, logging.PacketType0RTT, p2.Size(), logging.PacketDropDOSPrevention).Do(func(net.Addr, logging.PacketType, protocol.ByteCount, logging.PacketDropReason) { 1481 close(dropped2) 1482 }), 1483 ) 1484 1485 phm.EXPECT().Get(connID) 1486 serv.handlePacket(p) 1487 1488 // There's no cleanup Go routine. 1489 // Cleanup is triggered when new packets are received. 
1490 1491 phm.EXPECT().Get(connID2) 1492 serv.handlePacket(p2) 1493 // make sure no cleanup is executed 1494 Consistently(dropped1, 50*time.Millisecond).ShouldNot(BeClosed()) 1495 1496 // There's no cleanup Go routine. 1497 // Cleanup is triggered when new packets are received. 1498 connID3 := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 0}) 1499 p3 := getPacket(&wire.Header{ 1500 Type: protocol.PacketType0RTT, 1501 DestConnectionID: connID3, 1502 Version: serv.config.Versions[0], 1503 }, make([]byte, 200)) 1504 p3.rcvTime = now.Add(protocol.Max0RTTQueueingDuration + time.Nanosecond) // now triggers the cleanup 1505 phm.EXPECT().Get(connID3) 1506 serv.handlePacket(p3) 1507 Eventually(dropped1).Should(BeClosed()) 1508 Consistently(dropped2, 50*time.Millisecond).ShouldNot(BeClosed()) 1509 1510 // make sure the second packet is also cleaned up 1511 connID4 := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 1}) 1512 p4 := getPacket(&wire.Header{ 1513 Type: protocol.PacketType0RTT, 1514 DestConnectionID: connID4, 1515 Version: serv.config.Versions[0], 1516 }, make([]byte, 200)) 1517 p4.rcvTime = p2Time.Add(protocol.Max0RTTQueueingDuration + time.Nanosecond) // now triggers the cleanup 1518 phm.EXPECT().Get(connID4) 1519 serv.handlePacket(p4) 1520 Eventually(dropped2).Should(BeClosed()) 1521 }) 1522 }) 1523 })