github.com/danielpfeifer02/quic-go-prio-packs@v0.41.0-28/server_test.go (about) 1 package quic 2 3 import ( 4 "context" 5 "crypto/rand" 6 "crypto/tls" 7 "errors" 8 "net" 9 "sync" 10 "sync/atomic" 11 "time" 12 13 "github.com/danielpfeifer02/quic-go-prio-packs/internal/handshake" 14 mocklogging "github.com/danielpfeifer02/quic-go-prio-packs/internal/mocks/logging" 15 "github.com/danielpfeifer02/quic-go-prio-packs/internal/protocol" 16 "github.com/danielpfeifer02/quic-go-prio-packs/internal/qerr" 17 "github.com/danielpfeifer02/quic-go-prio-packs/internal/testdata" 18 "github.com/danielpfeifer02/quic-go-prio-packs/internal/utils" 19 "github.com/danielpfeifer02/quic-go-prio-packs/internal/wire" 20 "github.com/danielpfeifer02/quic-go-prio-packs/logging" 21 22 . "github.com/onsi/ginkgo/v2" 23 . "github.com/onsi/gomega" 24 "go.uber.org/mock/gomock" 25 ) 26 27 var _ = Describe("Server", func() { 28 var ( 29 conn *MockPacketConn 30 tlsConf *tls.Config 31 ) 32 33 getPacket := func(hdr *wire.Header, p []byte) receivedPacket { 34 buf := getPacketBuffer() 35 hdr.Length = 4 + protocol.ByteCount(len(p)) + 16 36 var err error 37 buf.Data, err = (&wire.ExtendedHeader{ 38 Header: *hdr, 39 PacketNumber: 0x42, 40 PacketNumberLen: protocol.PacketNumberLen4, 41 }).Append(buf.Data, protocol.Version1) 42 Expect(err).ToNot(HaveOccurred()) 43 n := len(buf.Data) 44 buf.Data = append(buf.Data, p...) 45 data := buf.Data 46 sealer, _ := handshake.NewInitialAEAD(hdr.DestConnectionID, protocol.PerspectiveClient, hdr.Version) 47 _ = sealer.Seal(data[n:n], data[n:], 0x42, data[:n]) 48 data = data[:len(data)+16] 49 sealer.EncryptHeader(data[n:n+16], &data[0], data[n-4:n]) 50 return receivedPacket{ 51 remoteAddr: &net.UDPAddr{IP: net.IPv4(4, 5, 6, 7), Port: 456}, 52 data: data, 53 buffer: buf, 54 } 55 } 56 57 getInitial := func(destConnID protocol.ConnectionID) receivedPacket { 58 senderAddr := &net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 42} 59 hdr := &wire.Header{ 60 Type: protocol.PacketTypeInitial, 61 SrcConnectionID: protocol.ParseConnectionID([]byte{5, 4, 3, 2, 1}), 62 DestConnectionID: destConnID, 63 Version: protocol.Version1, 64 } 65 p := getPacket(hdr, make([]byte, protocol.MinInitialPacketSize)) 66 p.buffer = getPacketBuffer() 67 p.remoteAddr = senderAddr 68 return p 69 } 70 71 getInitialWithRandomDestConnID := func() receivedPacket { 72 b := make([]byte, 10) 73 _, err := rand.Read(b) 74 Expect(err).ToNot(HaveOccurred()) 75 76 return getInitial(protocol.ParseConnectionID(b)) 77 } 78 79 parseHeader := func(data []byte) *wire.Header { 80 hdr, _, _, err := wire.ParsePacket(data) 81 Expect(err).ToNot(HaveOccurred()) 82 return hdr 83 } 84 85 checkConnectionCloseError := func(b []byte, origHdr *wire.Header, errorCode qerr.TransportErrorCode) { 86 replyHdr := parseHeader(b) 87 Expect(replyHdr.Type).To(Equal(protocol.PacketTypeInitial)) 88 Expect(replyHdr.SrcConnectionID).To(Equal(origHdr.DestConnectionID)) 89 Expect(replyHdr.DestConnectionID).To(Equal(origHdr.SrcConnectionID)) 90 _, opener := handshake.NewInitialAEAD(origHdr.DestConnectionID, protocol.PerspectiveClient, replyHdr.Version) 91 extHdr, err := unpackLongHeader(opener, replyHdr, b, origHdr.Version) 92 Expect(err).ToNot(HaveOccurred()) 93 data, err := opener.Open(nil, b[extHdr.ParsedLen():], extHdr.PacketNumber, b[:extHdr.ParsedLen()]) 94 Expect(err).ToNot(HaveOccurred()) 95 _, f, err := wire.NewFrameParser(false).ParseNext(data, protocol.EncryptionInitial, origHdr.Version) 96 Expect(err).ToNot(HaveOccurred()) 97 Expect(f).To(BeAssignableToTypeOf(&wire.ConnectionCloseFrame{})) 98 ccf := f.(*wire.ConnectionCloseFrame) 99 Expect(ccf.IsApplicationError).To(BeFalse()) 100 Expect(ccf.ErrorCode).To(BeEquivalentTo(errorCode)) 101 Expect(ccf.ReasonPhrase).To(BeEmpty()) 102 } 103 104 BeforeEach(func() { 105 conn = NewMockPacketConn(mockCtrl) 106 conn.EXPECT().LocalAddr().Return(&net.UDPAddr{}).AnyTimes() 107 wait := make(chan struct{}) 108 conn.EXPECT().ReadFrom(gomock.Any()).DoAndReturn(func(_ []byte) (int, net.Addr, error) { 109 <-wait 110 return 0, nil, errors.New("done") 111 }).MaxTimes(1) 112 conn.EXPECT().SetReadDeadline(gomock.Any()).Do(func(time.Time) error { 113 close(wait) 114 conn.EXPECT().SetReadDeadline(time.Time{}) 115 return nil 116 }).MaxTimes(1) 117 tlsConf = testdata.GetTLSConfig() 118 tlsConf.NextProtos = []string{"proto1"} 119 }) 120 121 It("errors when no tls.Config is given", func() { 122 _, err := ListenAddr("localhost:0", nil, nil) 123 Expect(err).To(HaveOccurred()) 124 Expect(err.Error()).To(ContainSubstring("quic: tls.Config not set")) 125 }) 126 127 It("errors when the Config contains an invalid version", func() { 128 version := protocol.Version(0x1234) 129 _, err := Listen(nil, tlsConf, &Config{Versions: []protocol.Version{version}}) 130 Expect(err).To(MatchError("invalid QUIC version: 0x1234")) 131 }) 132 133 It("fills in default values if options are not set in the Config", func() { 134 ln, err := Listen(conn, tlsConf, &Config{}) 135 Expect(err).ToNot(HaveOccurred()) 136 server := ln.baseServer 137 Expect(server.config.Versions).To(Equal(protocol.SupportedVersions)) 138 Expect(server.config.HandshakeIdleTimeout).To(Equal(protocol.DefaultHandshakeIdleTimeout)) 139 Expect(server.config.MaxIdleTimeout).To(Equal(protocol.DefaultIdleTimeout)) 140 Expect(server.config.KeepAlivePeriod).To(BeZero()) 141 // stop the listener 142 Expect(ln.Close()).To(Succeed()) 143 }) 144 145 It("setups with the right values", func() { 146 supportedVersions := []protocol.Version{protocol.Version1} 147 config := Config{ 148 Versions: supportedVersions, 149 HandshakeIdleTimeout: 1337 * time.Hour, 150 MaxIdleTimeout: 42 * time.Minute, 151 KeepAlivePeriod: 5 * time.Second, 152 } 153 ln, err := Listen(conn, tlsConf, &config) 154 Expect(err).ToNot(HaveOccurred()) 155 server := ln.baseServer 156 Expect(server.connHandler).ToNot(BeNil()) 157 Expect(server.config.Versions).To(Equal(supportedVersions)) 158 Expect(server.config.HandshakeIdleTimeout).To(Equal(1337 * time.Hour)) 159 Expect(server.config.MaxIdleTimeout).To(Equal(42 * time.Minute)) 160 Expect(server.config.KeepAlivePeriod).To(Equal(5 * time.Second)) 161 // stop the listener 162 Expect(ln.Close()).To(Succeed()) 163 }) 164 165 It("listens on a given address", func() { 166 addr := "127.0.0.1:13579" 167 ln, err := ListenAddr(addr, tlsConf, &Config{}) 168 Expect(err).ToNot(HaveOccurred()) 169 Expect(ln.Addr().String()).To(Equal(addr)) 170 // stop the listener 171 Expect(ln.Close()).To(Succeed()) 172 }) 173 174 It("errors if given an invalid address", func() { 175 addr := "127.0.0.1" 176 _, err := ListenAddr(addr, tlsConf, &Config{}) 177 Expect(err).To(BeAssignableToTypeOf(&net.AddrError{})) 178 }) 179 180 It("errors if given an invalid address", func() { 181 addr := "1.1.1.1:1111" 182 _, err := ListenAddr(addr, tlsConf, &Config{}) 183 Expect(err).To(BeAssignableToTypeOf(&net.OpError{})) 184 }) 185 186 Context("server accepting connections that completed the handshake", func() { 187 var ( 188 tr *Transport 189 serv *baseServer 190 phm *MockPacketHandlerManager 191 tracer *mocklogging.MockTracer 192 ) 193 194 BeforeEach(func() { 195 var t *logging.Tracer 196 t, tracer = mocklogging.NewMockTracer(mockCtrl) 197 tr = &Transport{Conn: conn, Tracer: t} 198 ln, err := tr.Listen(tlsConf, nil) 199 Expect(err).ToNot(HaveOccurred()) 200 serv = ln.baseServer 201 phm = NewMockPacketHandlerManager(mockCtrl) 202 serv.connHandler = phm 203 }) 204 205 AfterEach(func() { 206 tracer.EXPECT().Close() 207 tr.Close() 208 }) 209 210 Context("handling packets", func() { 211 It("drops Initial packets with a too short connection ID", func() { 212 p := getPacket(&wire.Header{ 213 Type: protocol.PacketTypeInitial, 214 DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4}), 215 Version: serv.config.Versions[0], 216 }, nil) 217 tracer.EXPECT().DroppedPacket(p.remoteAddr, logging.PacketTypeInitial, p.Size(), logging.PacketDropUnexpectedPacket) 218 serv.handlePacket(p) 219 // make sure there are no Write calls on the packet conn 220 time.Sleep(50 * time.Millisecond) 221 }) 222 223 It("drops too small Initial", func() { 224 p := getPacket(&wire.Header{ 225 Type: protocol.PacketTypeInitial, 226 DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}), 227 Version: serv.config.Versions[0], 228 }, make([]byte, protocol.MinInitialPacketSize-100)) 229 tracer.EXPECT().DroppedPacket(p.remoteAddr, logging.PacketTypeInitial, p.Size(), logging.PacketDropUnexpectedPacket) 230 serv.handlePacket(p) 231 // make sure there are no Write calls on the packet conn 232 time.Sleep(50 * time.Millisecond) 233 }) 234 235 It("drops non-Initial packets", func() { 236 p := getPacket(&wire.Header{ 237 Type: protocol.PacketTypeHandshake, 238 Version: serv.config.Versions[0], 239 }, []byte("invalid")) 240 tracer.EXPECT().DroppedPacket(p.remoteAddr, logging.PacketTypeHandshake, p.Size(), logging.PacketDropUnexpectedPacket) 241 serv.handlePacket(p) 242 // make sure there are no Write calls on the packet conn 243 time.Sleep(50 * time.Millisecond) 244 }) 245 246 It("passes packets to existing connections", func() { 247 connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}) 248 p := getPacket(&wire.Header{ 249 Type: protocol.PacketTypeInitial, 250 DestConnectionID: connID, 251 Version: serv.config.Versions[0], 252 }, make([]byte, protocol.MinInitialPacketSize)) 253 conn := NewMockPacketHandler(mockCtrl) 254 phm.EXPECT().Get(connID).Return(conn, true) 255 handled := make(chan struct{}) 256 conn.EXPECT().handlePacket(p).Do(func(receivedPacket) { close(handled) }) 257 serv.handlePacket(p) 258 Eventually(handled).Should(BeClosed()) 259 }) 260 261 It("creates a connection when the token is accepted", func() { 262 serv.maxNumHandshakesUnvalidated = 0 263 raddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337} 264 retryToken, err := serv.tokenGenerator.NewRetryToken( 265 raddr, 266 protocol.ParseConnectionID([]byte{0xde, 0xad, 0xc0, 0xde}), 267 protocol.ParseConnectionID([]byte{0xde, 0xca, 0xfb, 0xad}), 268 ) 269 Expect(err).ToNot(HaveOccurred()) 270 connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}) 271 hdr := &wire.Header{ 272 Type: protocol.PacketTypeInitial, 273 SrcConnectionID: protocol.ParseConnectionID([]byte{5, 4, 3, 2, 1}), 274 DestConnectionID: connID, 275 Version: protocol.Version1, 276 Token: retryToken, 277 } 278 p := getPacket(hdr, make([]byte, protocol.MinInitialPacketSize)) 279 p.remoteAddr = raddr 280 run := make(chan struct{}) 281 var token protocol.StatelessResetToken 282 rand.Read(token[:]) 283 284 var newConnID protocol.ConnectionID 285 conn := NewMockQUICConn(mockCtrl) 286 serv.newConn = func( 287 _ sendConn, 288 _ connRunner, 289 origDestConnID protocol.ConnectionID, 290 retrySrcConnID *protocol.ConnectionID, 291 clientDestConnID protocol.ConnectionID, 292 destConnID protocol.ConnectionID, 293 srcConnID protocol.ConnectionID, 294 _ ConnectionIDGenerator, 295 tokenP protocol.StatelessResetToken, 296 _ *Config, 297 _ *tls.Config, 298 _ *handshake.TokenGenerator, 299 _ bool, 300 _ *logging.ConnectionTracer, 301 _ uint64, 302 _ utils.Logger, 303 _ protocol.Version, 304 ) quicConn { 305 Expect(origDestConnID).To(Equal(protocol.ParseConnectionID([]byte{0xde, 0xad, 0xc0, 0xde}))) 306 Expect(*retrySrcConnID).To(Equal(protocol.ParseConnectionID([]byte{0xde, 0xca, 0xfb, 0xad}))) 307 Expect(clientDestConnID).To(Equal(hdr.DestConnectionID)) 308 Expect(destConnID).To(Equal(hdr.SrcConnectionID)) 309 // make sure we're using a server-generated connection ID 310 Expect(srcConnID).ToNot(Equal(hdr.DestConnectionID)) 311 Expect(srcConnID).ToNot(Equal(hdr.SrcConnectionID)) 312 newConnID = srcConnID 313 Expect(tokenP).To(Equal(token)) 314 conn.EXPECT().handlePacket(p) 315 conn.EXPECT().run().Do(func() error { close(run); return nil }) 316 conn.EXPECT().Context().Return(context.Background()) 317 conn.EXPECT().HandshakeComplete().Return(make(chan struct{})) 318 return conn 319 } 320 phm.EXPECT().Get(connID) 321 phm.EXPECT().GetStatelessResetToken(gomock.Any()).Return(token) 322 phm.EXPECT().AddWithConnID(connID, gomock.Any(), gomock.Any()).DoAndReturn(func(_, cid protocol.ConnectionID, h packetHandler) bool { 323 Expect(cid).To(Equal(newConnID)) 324 return true 325 }) 326 327 done := make(chan struct{}) 328 go func() { 329 defer GinkgoRecover() 330 serv.handlePacket(p) 331 // the Handshake packet is written by the connection. 332 // Make sure there are no Write calls on the packet conn. 333 time.Sleep(50 * time.Millisecond) 334 close(done) 335 }() 336 // make sure we're using a server-generated connection ID 337 Eventually(run).Should(BeClosed()) 338 Eventually(done).Should(BeClosed()) 339 // shutdown 340 conn.EXPECT().closeWithTransportError(gomock.Any()) 341 }) 342 343 It("sends a Version Negotiation Packet for unsupported versions", func() { 344 srcConnID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5}) 345 destConnID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6}) 346 packet := getPacket(&wire.Header{ 347 Type: protocol.PacketTypeHandshake, 348 SrcConnectionID: srcConnID, 349 DestConnectionID: destConnID, 350 Version: 0x42, 351 }, make([]byte, protocol.MinUnknownVersionPacketSize)) 352 raddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337} 353 packet.remoteAddr = raddr 354 tracer.EXPECT().SentVersionNegotiationPacket(packet.remoteAddr, gomock.Any(), gomock.Any(), gomock.Any()).Do(func(_ net.Addr, src, dest protocol.ArbitraryLenConnectionID, _ []protocol.Version) { 355 Expect(src).To(Equal(protocol.ArbitraryLenConnectionID(destConnID.Bytes()))) 356 Expect(dest).To(Equal(protocol.ArbitraryLenConnectionID(srcConnID.Bytes()))) 357 }) 358 done := make(chan struct{}) 359 conn.EXPECT().WriteTo(gomock.Any(), raddr).DoAndReturn(func(b []byte, _ net.Addr) (int, error) { 360 defer close(done) 361 Expect(wire.IsVersionNegotiationPacket(b)).To(BeTrue()) 362 dest, src, versions, err := wire.ParseVersionNegotiationPacket(b) 363 Expect(err).ToNot(HaveOccurred()) 364 Expect(dest).To(Equal(protocol.ArbitraryLenConnectionID(srcConnID.Bytes()))) 365 Expect(src).To(Equal(protocol.ArbitraryLenConnectionID(destConnID.Bytes()))) 366 Expect(versions).ToNot(ContainElement(protocol.Version(0x42))) 367 return len(b), nil 368 }) 369 serv.handlePacket(packet) 370 Eventually(done).Should(BeClosed()) 371 }) 372 373 It("doesn't send a Version Negotiation packets if sending them is disabled", func() { 374 serv.disableVersionNegotiation = true 375 srcConnID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5}) 376 destConnID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6}) 377 packet := getPacket(&wire.Header{ 378 Type: protocol.PacketTypeHandshake, 379 SrcConnectionID: srcConnID, 380 DestConnectionID: destConnID, 381 Version: 0x42, 382 }, make([]byte, protocol.MinUnknownVersionPacketSize)) 383 raddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337} 384 packet.remoteAddr = raddr 385 done := make(chan struct{}) 386 serv.handlePacket(packet) 387 Consistently(done, 50*time.Millisecond).ShouldNot(BeClosed()) 388 }) 389 390 It("ignores Version Negotiation packets", func() { 391 data := wire.ComposeVersionNegotiation( 392 protocol.ArbitraryLenConnectionID{1, 2, 3, 4}, 393 protocol.ArbitraryLenConnectionID{4, 3, 2, 1}, 394 []protocol.Version{1, 2, 3}, 395 ) 396 raddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337} 397 done := make(chan struct{}) 398 tracer.EXPECT().DroppedPacket(raddr, logging.PacketTypeVersionNegotiation, protocol.ByteCount(len(data)), logging.PacketDropUnexpectedPacket).Do(func(net.Addr, logging.PacketType, protocol.ByteCount, logging.PacketDropReason) { 399 close(done) 400 }) 401 serv.handlePacket(receivedPacket{ 402 remoteAddr: raddr, 403 data: data, 404 buffer: getPacketBuffer(), 405 }) 406 Eventually(done).Should(BeClosed()) 407 // make sure no other packet is sent 408 time.Sleep(scaleDuration(20 * time.Millisecond)) 409 }) 410 411 It("doesn't send a Version Negotiation Packet for unsupported versions, if the packet is too small", func() { 412 srcConnID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5}) 413 destConnID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6}) 414 p := getPacket(&wire.Header{ 415 Type: protocol.PacketTypeHandshake, 416 SrcConnectionID: srcConnID, 417 DestConnectionID: destConnID, 418 Version: 0x42, 419 }, make([]byte, protocol.MinUnknownVersionPacketSize-50)) 420 Expect(p.Size()).To(BeNumerically("<", protocol.MinUnknownVersionPacketSize)) 421 raddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337} 422 p.remoteAddr = raddr 423 done := make(chan struct{}) 424 tracer.EXPECT().DroppedPacket(raddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropUnexpectedPacket).Do(func(net.Addr, logging.PacketType, protocol.ByteCount, logging.PacketDropReason) { 425 close(done) 426 }) 427 serv.handlePacket(p) 428 Eventually(done).Should(BeClosed()) 429 // make sure no other packet is sent 430 time.Sleep(scaleDuration(20 * time.Millisecond)) 431 }) 432 433 It("replies with a Retry packet, if a token is required", func() { 434 connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}) 435 serv.maxNumHandshakesUnvalidated = 0 436 hdr := &wire.Header{ 437 Type: protocol.PacketTypeInitial, 438 SrcConnectionID: protocol.ParseConnectionID([]byte{5, 4, 3, 2, 1}), 439 DestConnectionID: connID, 440 Version: protocol.Version1, 441 } 442 packet := getPacket(hdr, make([]byte, protocol.MinInitialPacketSize)) 443 raddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337} 444 packet.remoteAddr = raddr 445 tracer.EXPECT().SentPacket(packet.remoteAddr, gomock.Any(), gomock.Any(), nil).Do(func(_ net.Addr, replyHdr *logging.Header, _ logging.ByteCount, _ []logging.Frame) { 446 Expect(replyHdr.Type).To(Equal(protocol.PacketTypeRetry)) 447 Expect(replyHdr.SrcConnectionID).ToNot(Equal(hdr.DestConnectionID)) 448 Expect(replyHdr.DestConnectionID).To(Equal(hdr.SrcConnectionID)) 449 Expect(replyHdr.Token).ToNot(BeEmpty()) 450 }) 451 done := make(chan struct{}) 452 conn.EXPECT().WriteTo(gomock.Any(), raddr).DoAndReturn(func(b []byte, _ net.Addr) (int, error) { 453 defer close(done) 454 replyHdr := parseHeader(b) 455 Expect(replyHdr.Type).To(Equal(protocol.PacketTypeRetry)) 456 Expect(replyHdr.SrcConnectionID).ToNot(Equal(hdr.DestConnectionID)) 457 Expect(replyHdr.DestConnectionID).To(Equal(hdr.SrcConnectionID)) 458 Expect(replyHdr.Token).ToNot(BeEmpty()) 459 Expect(b[len(b)-16:]).To(Equal(handshake.GetRetryIntegrityTag(b[:len(b)-16], hdr.DestConnectionID, hdr.Version)[:])) 460 return len(b), nil 461 }) 462 phm.EXPECT().Get(connID) 463 serv.handlePacket(packet) 464 Eventually(done).Should(BeClosed()) 465 }) 466 467 It("creates a connection, if no token is required", func() { 468 connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}) 469 hdr := &wire.Header{ 470 Type: protocol.PacketTypeInitial, 471 SrcConnectionID: protocol.ParseConnectionID([]byte{5, 4, 3, 2, 1}), 472 DestConnectionID: connID, 473 Version: protocol.Version1, 474 } 475 p := getPacket(hdr, make([]byte, protocol.MinInitialPacketSize)) 476 run := make(chan struct{}) 477 var token protocol.StatelessResetToken 478 rand.Read(token[:]) 479 480 var newConnID protocol.ConnectionID 481 conn := NewMockQUICConn(mockCtrl) 482 serv.newConn = func( 483 _ sendConn, 484 _ connRunner, 485 origDestConnID protocol.ConnectionID, 486 retrySrcConnID *protocol.ConnectionID, 487 clientDestConnID protocol.ConnectionID, 488 destConnID protocol.ConnectionID, 489 srcConnID protocol.ConnectionID, 490 _ ConnectionIDGenerator, 491 tokenP protocol.StatelessResetToken, 492 _ *Config, 493 _ *tls.Config, 494 _ *handshake.TokenGenerator, 495 _ bool, 496 _ *logging.ConnectionTracer, 497 _ uint64, 498 _ utils.Logger, 499 _ protocol.Version, 500 ) quicConn { 501 Expect(origDestConnID).To(Equal(hdr.DestConnectionID)) 502 Expect(retrySrcConnID).To(BeNil()) 503 Expect(clientDestConnID).To(Equal(hdr.DestConnectionID)) 504 Expect(destConnID).To(Equal(hdr.SrcConnectionID)) 505 // make sure we're using a server-generated connection ID 506 Expect(srcConnID).ToNot(Equal(hdr.DestConnectionID)) 507 Expect(srcConnID).ToNot(Equal(hdr.SrcConnectionID)) 508 newConnID = srcConnID 509 Expect(tokenP).To(Equal(token)) 510 conn.EXPECT().handlePacket(p) 511 conn.EXPECT().run().Do(func() error { close(run); return nil }) 512 conn.EXPECT().Context().Return(context.Background()) 513 conn.EXPECT().HandshakeComplete().Return(make(chan struct{})) 514 return conn 515 } 516 gomock.InOrder( 517 phm.EXPECT().Get(connID), 518 phm.EXPECT().GetStatelessResetToken(gomock.Any()).Return(token), 519 phm.EXPECT().AddWithConnID(connID, gomock.Any(), gomock.Any()).DoAndReturn(func(_, c protocol.ConnectionID, h packetHandler) bool { 520 Expect(c).To(Equal(newConnID)) 521 return true 522 }), 523 ) 524 525 done := make(chan struct{}) 526 go func() { 527 defer GinkgoRecover() 528 serv.handlePacket(p) 529 // the Handshake packet is written by the connection 530 // make sure there are no Write calls on the packet conn 531 time.Sleep(50 * time.Millisecond) 532 close(done) 533 }() 534 // make sure we're using a server-generated connection ID 535 Eventually(run).Should(BeClosed()) 536 Eventually(done).Should(BeClosed()) 537 // shutdown 538 conn.EXPECT().closeWithTransportError(gomock.Any()).MaxTimes(1) 539 }) 540 541 It("drops packets if the receive queue is full", func() { 542 serv.maxNumHandshakesTotal = 10000 543 serv.maxNumHandshakesUnvalidated = 10000 544 545 phm.EXPECT().Get(gomock.Any()).AnyTimes() 546 phm.EXPECT().GetStatelessResetToken(gomock.Any()).AnyTimes() 547 phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).Return(true).AnyTimes() 548 549 acceptConn := make(chan struct{}) 550 var counter atomic.Uint32 551 serv.newConn = func( 552 _ sendConn, 553 runner connRunner, 554 _ protocol.ConnectionID, 555 _ *protocol.ConnectionID, 556 _ protocol.ConnectionID, 557 _ protocol.ConnectionID, 558 _ protocol.ConnectionID, 559 _ ConnectionIDGenerator, 560 _ protocol.StatelessResetToken, 561 _ *Config, 562 _ *tls.Config, 563 _ *handshake.TokenGenerator, 564 _ bool, 565 _ *logging.ConnectionTracer, 566 _ uint64, 567 _ utils.Logger, 568 _ protocol.Version, 569 ) quicConn { 570 <-acceptConn 571 counter.Add(1) 572 conn := NewMockQUICConn(mockCtrl) 573 conn.EXPECT().handlePacket(gomock.Any()).MaxTimes(1) 574 conn.EXPECT().run().MaxTimes(1) 575 conn.EXPECT().Context().Return(context.Background()).MaxTimes(1) 576 conn.EXPECT().HandshakeComplete().Return(make(chan struct{})).MaxTimes(1) 577 // shutdown 578 conn.EXPECT().closeWithTransportError(gomock.Any()).MaxTimes(1) 579 return conn 580 } 581 582 p := getInitial(protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8})) 583 serv.handlePacket(p) 584 tracer.EXPECT().DroppedPacket(p.remoteAddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropDOSPrevention).MinTimes(1) 585 var wg sync.WaitGroup 586 for i := 0; i < 3*protocol.MaxServerUnprocessedPackets; i++ { 587 wg.Add(1) 588 go func() { 589 defer GinkgoRecover() 590 defer wg.Done() 591 serv.handlePacket(getInitial(protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}))) 592 }() 593 } 594 wg.Wait() 595 596 close(acceptConn) 597 Eventually( 598 func() uint32 { return counter.Load() }, 599 scaleDuration(100*time.Millisecond), 600 ).Should(BeEquivalentTo(protocol.MaxServerUnprocessedPackets + 1)) 601 Consistently(func() uint32 { return counter.Load() }).Should(BeEquivalentTo(protocol.MaxServerUnprocessedPackets + 1)) 602 }) 603 604 It("only creates a single connection for a duplicate Initial", func() { 605 done := make(chan struct{}) 606 serv.newConn = func( 607 _ sendConn, 608 runner connRunner, 609 _ protocol.ConnectionID, 610 _ *protocol.ConnectionID, 611 _ protocol.ConnectionID, 612 _ protocol.ConnectionID, 613 _ protocol.ConnectionID, 614 _ ConnectionIDGenerator, 615 _ protocol.StatelessResetToken, 616 _ *Config, 617 _ *tls.Config, 618 _ *handshake.TokenGenerator, 619 _ bool, 620 _ *logging.ConnectionTracer, 621 _ uint64, 622 _ utils.Logger, 623 _ protocol.Version, 624 ) quicConn { 625 conn := NewMockQUICConn(mockCtrl) 626 conn.EXPECT().handlePacket(gomock.Any()) 627 conn.EXPECT().closeWithTransportError(qerr.ConnectionRefused).Do(func(qerr.TransportErrorCode) { 628 close(done) 629 }) 630 return conn 631 } 632 633 connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9}) 634 p := getInitial(connID) 635 phm.EXPECT().Get(connID) 636 phm.EXPECT().GetStatelessResetToken(gomock.Any()) 637 phm.EXPECT().AddWithConnID(connID, gomock.Any(), gomock.Any()).Return(false) // connection ID collision 638 Expect(serv.handlePacketImpl(p)).To(BeTrue()) 639 Eventually(done).Should(BeClosed()) 640 }) 641 642 It("limits the number of unvalidated handshakes", func() { 643 const limit = 3 644 serv.maxNumHandshakesTotal = 10000 645 serv.maxNumHandshakesUnvalidated = limit 646 647 phm.EXPECT().Get(gomock.Any()).AnyTimes() 648 phm.EXPECT().GetStatelessResetToken(gomock.Any()).AnyTimes() 649 phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).Return(true).AnyTimes() 650 651 handshakeChan := make(chan struct{}) 652 connChan := make(chan *MockQUICConn, 1) 653 var wg sync.WaitGroup 654 wg.Add(2 * limit) 655 serv.newConn = func( 656 _ sendConn, 657 runner connRunner, 658 _ protocol.ConnectionID, 659 _ *protocol.ConnectionID, 660 _ protocol.ConnectionID, 661 _ protocol.ConnectionID, 662 _ protocol.ConnectionID, 663 _ ConnectionIDGenerator, 664 _ protocol.StatelessResetToken, 665 _ *Config, 666 _ *tls.Config, 667 _ *handshake.TokenGenerator, 668 _ bool, 669 _ *logging.ConnectionTracer, 670 _ uint64, 671 _ utils.Logger, 672 _ protocol.Version, 673 ) quicConn { 674 conn := <-connChan 675 conn.EXPECT().handlePacket(gomock.Any()) 676 conn.EXPECT().run() 677 conn.EXPECT().Context().Return(context.Background()) 678 conn.EXPECT().HandshakeComplete().Return(handshakeChan).Do(func() <-chan struct{} { wg.Done(); return nil }) 679 return conn 680 } 681 682 // Initiate the maximum number of allowed connection attempts. 683 for i := 0; i < limit; i++ { 684 conn := NewMockQUICConn(mockCtrl) 685 connChan <- conn 686 serv.handlePacket(getInitialWithRandomDestConnID()) 687 } 688 689 // Now initiate another connection attempt. 690 p := getInitialWithRandomDestConnID() 691 done := make(chan struct{}) 692 tracer.EXPECT().SentPacket(p.remoteAddr, gomock.Any(), gomock.Any(), gomock.Any()).Do(func(_ net.Addr, replyHdr *logging.Header, _ logging.ByteCount, frames []logging.Frame) { 693 defer GinkgoRecover() 694 Expect(replyHdr.Type).To(Equal(protocol.PacketTypeRetry)) 695 }) 696 conn.EXPECT().WriteTo(gomock.Any(), gomock.Any()).DoAndReturn(func(b []byte, _ net.Addr) (int, error) { 697 defer GinkgoRecover() 698 defer close(done) 699 hdr, _, _, err := wire.ParsePacket(b) 700 Expect(err).ToNot(HaveOccurred()) 701 Expect(hdr.Type).To(Equal(protocol.PacketTypeRetry)) 702 return len(b), nil 703 }) 704 serv.handlePacket(p) 705 Eventually(done).Should(BeClosed()) 706 707 close(handshakeChan) 708 for i := 0; i < limit; i++ { 709 _, err := serv.Accept(context.Background()) 710 Expect(err).ToNot(HaveOccurred()) 711 } 712 for i := 0; i < limit; i++ { 713 conn := NewMockQUICConn(mockCtrl) 714 conn.EXPECT().closeWithTransportError(gomock.Any()).MaxTimes(1) // called when the server is closed 715 connChan <- conn 716 serv.handlePacket(getInitialWithRandomDestConnID()) 717 } 718 wg.Wait() 719 }) 720 721 It("limits the number of total handshakes", func() { 722 const limit = 3 723 serv.maxNumHandshakesTotal = limit 724 serv.maxNumHandshakesUnvalidated = limit // same limit, but we check that we send CONNECTION_REFUSED and not Retry 725 726 phm.EXPECT().Get(gomock.Any()).AnyTimes() 727 phm.EXPECT().GetStatelessResetToken(gomock.Any()).AnyTimes() 728 phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).Return(true).AnyTimes() 729 730 handshakeChan := make(chan struct{}) 731 connChan := make(chan *MockQUICConn, 1) 732 serv.newConn = func( 733 _ sendConn, 734 runner connRunner, 735 _ protocol.ConnectionID, 736 _ *protocol.ConnectionID, 737 _ protocol.ConnectionID, 738 _ protocol.ConnectionID, 739 _ protocol.ConnectionID, 740 _ ConnectionIDGenerator, 741 _ protocol.StatelessResetToken, 742 _ *Config, 743 _ *tls.Config, 744 _ *handshake.TokenGenerator, 745 _ bool, 746 _ *logging.ConnectionTracer, 747 _ uint64, 748 _ utils.Logger, 749 _ protocol.Version, 750 ) quicConn { 751 conn := <-connChan 752 conn.EXPECT().handlePacket(gomock.Any()) 753 conn.EXPECT().run() 754 conn.EXPECT().Context().Return(context.Background()) 755 conn.EXPECT().HandshakeComplete().Return(handshakeChan) 756 return conn 757 } 758 759 for i := 0; i < limit; i++ { 760 conn := NewMockQUICConn(mockCtrl) 761 connChan <- conn 762 serv.handlePacket(getInitialWithRandomDestConnID()) 763 } 764 765 p := getInitialWithRandomDestConnID() 766 done := make(chan struct{}) 767 tracer.EXPECT().SentPacket(p.remoteAddr, gomock.Any(), gomock.Any(), gomock.Any()).Do(func(_ net.Addr, replyHdr *logging.Header, _ logging.ByteCount, frames []logging.Frame) { 768 defer GinkgoRecover() 769 hdr, _, _, err := wire.ParsePacket(p.data) 770 Expect(err).ToNot(HaveOccurred()) 771 Expect(replyHdr.Type).To(Equal(protocol.PacketTypeInitial)) 772 Expect(replyHdr.SrcConnectionID).To(Equal(hdr.DestConnectionID)) 773 Expect(replyHdr.DestConnectionID).To(Equal(hdr.SrcConnectionID)) 774 Expect(frames).To(HaveLen(1)) 775 Expect(frames[0]).To(BeAssignableToTypeOf(&wire.ConnectionCloseFrame{})) 776 ccf := frames[0].(*logging.ConnectionCloseFrame) 777 Expect(ccf.IsApplicationError).To(BeFalse()) 778 Expect(ccf.ErrorCode).To(BeEquivalentTo(qerr.ConnectionRefused)) 779 }) 780 conn.EXPECT().WriteTo(gomock.Any(), gomock.Any()).DoAndReturn(func(b []byte, _ net.Addr) (int, error) { 781 defer GinkgoRecover() 782 defer close(done) 783 hdr, _, _, err := wire.ParsePacket(p.data) 784 Expect(err).ToNot(HaveOccurred()) 785 checkConnectionCloseError(b, hdr, qerr.ConnectionRefused) 786 return len(b), nil 787 }) 788 serv.handlePacket(p) 789 Eventually(done).Should(BeClosed()) 790 791 close(handshakeChan) 792 for i := 0; i < limit; i++ { 793 _, err := serv.Accept(context.Background()) 794 Expect(err).ToNot(HaveOccurred()) 795 } 796 // make sure we can enqueue and accept more connections after that 797 for i := 0; i < limit; i++ { 798 conn := NewMockQUICConn(mockCtrl) 799 conn.EXPECT().closeWithTransportError(gomock.Any()).MaxTimes(1) // called when the server is closed 800 connChan <- conn 801 serv.handlePacket(getInitialWithRandomDestConnID()) 802 } 803 for i := 0; i < limit; i++ { 804 _, err := serv.Accept(context.Background()) 805 Expect(err).ToNot(HaveOccurred()) 806 } 807 }) 808 }) 809 810 Context("token validation", func() { 811 It("decodes the token from the token field", func() { 812 raddr := &net.UDPAddr{IP: net.IPv4(192, 168, 13, 37), Port: 1337} 813 token, err := serv.tokenGenerator.NewRetryToken(raddr, protocol.ConnectionID{}, protocol.ConnectionID{}) 814 Expect(err).ToNot(HaveOccurred()) 815 packet := getPacket(&wire.Header{ 816 Type: protocol.PacketTypeInitial, 817 Token: token, 818 Version: serv.config.Versions[0], 819 }, make([]byte, protocol.MinInitialPacketSize)) 820 packet.remoteAddr = raddr 821 conn.EXPECT().WriteTo(gomock.Any(), gomock.Any()).MaxTimes(1) 822 tracer.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).MaxTimes(1) 823 824 done := make(chan struct{}) 825 phm.EXPECT().Get(gomock.Any()) 826 phm.EXPECT().GetStatelessResetToken(gomock.Any()) 827 phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, _ packetHandler) bool { 828 close(done) 829 return true 830 }) 831 phm.EXPECT().Remove(gomock.Any()).AnyTimes() 832 serv.handlePacket(packet) 833 Eventually(done).Should(BeClosed()) 834 }) 835 836 It("sends an INVALID_TOKEN error, if an invalid retry token is received", func() { 837 serv.maxNumHandshakesUnvalidated = 0 838 token, err := serv.tokenGenerator.NewRetryToken(&net.UDPAddr{}, protocol.ConnectionID{}, protocol.ConnectionID{}) 839 Expect(err).ToNot(HaveOccurred()) 840 hdr := &wire.Header{ 841 Type: protocol.PacketTypeInitial, 842 SrcConnectionID: protocol.ParseConnectionID([]byte{5, 4, 3, 2, 1}), 843 DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}), 844 Token: token, 845 Version: protocol.Version1, 846 } 847 packet := getPacket(hdr, make([]byte, protocol.MinInitialPacketSize)) 848 packet.data = append(packet.data, []byte("coalesced packet")...) // add some garbage to simulate a coalesced packet 849 raddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337} 850 packet.remoteAddr = raddr 851 tracer.EXPECT().SentPacket(packet.remoteAddr, gomock.Any(), gomock.Any(), gomock.Any()).Do(func(_ net.Addr, replyHdr *logging.Header, _ logging.ByteCount, frames []logging.Frame) { 852 Expect(replyHdr.Type).To(Equal(protocol.PacketTypeInitial)) 853 Expect(replyHdr.SrcConnectionID).To(Equal(hdr.DestConnectionID)) 854 Expect(replyHdr.DestConnectionID).To(Equal(hdr.SrcConnectionID)) 855 Expect(frames).To(HaveLen(1)) 856 Expect(frames[0]).To(BeAssignableToTypeOf(&wire.ConnectionCloseFrame{})) 857 ccf := frames[0].(*logging.ConnectionCloseFrame) 858 Expect(ccf.IsApplicationError).To(BeFalse()) 859 Expect(ccf.ErrorCode).To(BeEquivalentTo(qerr.InvalidToken)) 860 }) 861 done := make(chan struct{}) 862 conn.EXPECT().WriteTo(gomock.Any(), raddr).DoAndReturn(func(b []byte, _ net.Addr) (int, error) { 863 defer close(done) 864 checkConnectionCloseError(b, hdr, qerr.InvalidToken) 865 return len(b), nil 866 }) 867 phm.EXPECT().Get(gomock.Any()) 868 serv.handlePacket(packet) 869 Eventually(done).Should(BeClosed()) 870 }) 871 872 It("sends an INVALID_TOKEN error, if an expired retry token is received", func() { 873 serv.maxNumHandshakesUnvalidated = 0 874 serv.config.HandshakeIdleTimeout = time.Millisecond / 2 // the maximum retry token age is equivalent to the handshake timeout 875 Expect(serv.config.maxRetryTokenAge()).To(Equal(time.Millisecond)) 876 raddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337} 877 token, err := serv.tokenGenerator.NewRetryToken(raddr, protocol.ConnectionID{}, protocol.ConnectionID{}) 878 Expect(err).ToNot(HaveOccurred()) 879 time.Sleep(2 * time.Millisecond) // make sure the token is expired 880 hdr := &wire.Header{ 881 Type: protocol.PacketTypeInitial, 882 SrcConnectionID: protocol.ParseConnectionID([]byte{5, 4, 3, 2, 1}), 883 DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}), 884 Token: token, 885 Version: protocol.Version1, 886 } 887 packet := getPacket(hdr, make([]byte, protocol.MinInitialPacketSize)) 888 packet.remoteAddr = raddr 889 tracer.EXPECT().SentPacket(packet.remoteAddr, gomock.Any(), gomock.Any(), gomock.Any()).Do(func(_ net.Addr, replyHdr *logging.Header, _ logging.ByteCount, frames []logging.Frame) { 890 Expect(replyHdr.Type).To(Equal(protocol.PacketTypeInitial)) 891 Expect(replyHdr.SrcConnectionID).To(Equal(hdr.DestConnectionID)) 892 Expect(replyHdr.DestConnectionID).To(Equal(hdr.SrcConnectionID)) 893 Expect(frames).To(HaveLen(1)) 894 Expect(frames[0]).To(BeAssignableToTypeOf(&wire.ConnectionCloseFrame{})) 895 ccf := frames[0].(*logging.ConnectionCloseFrame) 896 Expect(ccf.IsApplicationError).To(BeFalse()) 897 Expect(ccf.ErrorCode).To(BeEquivalentTo(qerr.InvalidToken)) 898 }) 899 done := make(chan struct{}) 900 conn.EXPECT().WriteTo(gomock.Any(), raddr).DoAndReturn(func(b []byte, _ net.Addr) (int, error) { 901 defer close(done) 902 checkConnectionCloseError(b, hdr, qerr.InvalidToken) 903 return len(b), nil 904 }) 905 phm.EXPECT().Get(gomock.Any()) 906 serv.handlePacket(packet) 907 Eventually(done).Should(BeClosed()) 908 }) 909 910 It("doesn't send an INVALID_TOKEN error, if an invalid non-retry token is received", func() { 911 serv.maxNumHandshakesUnvalidated = 0 912 token, err := serv.tokenGenerator.NewToken(&net.UDPAddr{IP: net.IPv4(192, 168, 0, 1), Port: 1337}) 913 Expect(err).ToNot(HaveOccurred()) 914 hdr := &wire.Header{ 915 Type: protocol.PacketTypeInitial, 916 SrcConnectionID: protocol.ParseConnectionID([]byte{5, 4, 3, 2, 1}), 917 DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}), 918 Token: token, 919 Version: protocol.Version1, 920 } 921 packet := getPacket(hdr, make([]byte, protocol.MinInitialPacketSize)) 922 packet.data[len(packet.data)-10] ^= 0xff // corrupt the packet 923 raddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337} 924 packet.remoteAddr = raddr 925 tracer.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).MaxTimes(1) 926 done := make(chan struct{}) 927 conn.EXPECT().WriteTo(gomock.Any(), raddr).DoAndReturn(func(b []byte, _ net.Addr) (int, error) { 928 defer close(done) 929 replyHdr := parseHeader(b) 930 Expect(replyHdr.Type).To(Equal(protocol.PacketTypeRetry)) 931 return len(b), nil 932 }) 933 phm.EXPECT().Get(gomock.Any()) 934 serv.handlePacket(packet) 935 // make sure there are no Write calls on the packet conn 936 Eventually(done).Should(BeClosed()) 937 }) 938 939 It("sends an INVALID_TOKEN error, if an expired non-retry token is received", func() { 940 serv.maxNumHandshakesUnvalidated = 0 941 serv.maxTokenAge = time.Millisecond 942 raddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337} 943 token, err := serv.tokenGenerator.NewToken(raddr) 944 Expect(err).ToNot(HaveOccurred()) 945 time.Sleep(2 * time.Millisecond) // make sure the token is expired 946 hdr := &wire.Header{ 947 Type: protocol.PacketTypeInitial, 948 SrcConnectionID: protocol.ParseConnectionID([]byte{5, 4, 3, 2, 1}), 949 DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}), 950 Token: token, 951 Version: protocol.Version1, 952 } 953 packet := getPacket(hdr, make([]byte, protocol.MinInitialPacketSize)) 954 packet.remoteAddr = raddr 955 tracer.EXPECT().SentPacket(packet.remoteAddr, gomock.Any(), gomock.Any(), gomock.Any()).Do(func(_ net.Addr, replyHdr *logging.Header, _ logging.ByteCount, frames []logging.Frame) { 956 Expect(replyHdr.Type).To(Equal(protocol.PacketTypeRetry)) 957 }) 958 done := make(chan struct{}) 959 conn.EXPECT().WriteTo(gomock.Any(), raddr).DoAndReturn(func(b []byte, _ net.Addr) (int, error) { 960 defer close(done) 961 return len(b), nil 962 }) 963 phm.EXPECT().Get(gomock.Any()) 964 serv.handlePacket(packet) 965 Eventually(done).Should(BeClosed()) 966 }) 967 968 It("doesn't send an INVALID_TOKEN error, if the packet is corrupted", func() { 969 serv.maxNumHandshakesUnvalidated = 0 970 token, err := serv.tokenGenerator.NewRetryToken(&net.UDPAddr{}, protocol.ConnectionID{}, protocol.ConnectionID{}) 971 Expect(err).ToNot(HaveOccurred()) 972 hdr := &wire.Header{ 973 Type: protocol.PacketTypeInitial, 974 SrcConnectionID: protocol.ParseConnectionID([]byte{5, 4, 3, 2, 1}), 975 DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}), 976 Token: token, 977 Version: protocol.Version1, 978 } 979 packet := getPacket(hdr, make([]byte, protocol.MinInitialPacketSize)) 980 packet.data[len(packet.data)-10] ^= 0xff // corrupt the packet 981 packet.remoteAddr = &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337} 982 done := make(chan struct{}) 983 tracer.EXPECT().DroppedPacket(packet.remoteAddr, logging.PacketTypeInitial, packet.Size(), logging.PacketDropPayloadDecryptError).Do(func(net.Addr, logging.PacketType, protocol.ByteCount, logging.PacketDropReason) { close(done) }) 984 phm.EXPECT().Get(gomock.Any()) 985 serv.handlePacket(packet) 986 // make sure there are no Write calls on the packet conn 987 time.Sleep(50 * time.Millisecond) 988 Eventually(done).Should(BeClosed()) 989 }) 990 }) 991 992 Context("accepting connections", func() { 993 It("returns Accept when closed", func() { 994 done := make(chan struct{}) 995 go func() { 996 defer GinkgoRecover() 997 _, err := serv.Accept(context.Background()) 998 Expect(err).To(MatchError(ErrServerClosed)) 999 close(done) 1000 }() 1001 1002 serv.Close() 1003 Eventually(done).Should(BeClosed()) 1004 }) 1005 1006 It("returns immediately, if an error occurred before", func() { 1007 serv.Close() 1008 for i := 0; i < 3; i++ { 1009 _, err := serv.Accept(context.Background()) 1010 Expect(err).To(MatchError(ErrServerClosed)) 1011 } 1012 }) 1013 1014 It("closes connection that are still handshaking after Close", func() { 1015 serv.Close() 1016 1017 destroyed := make(chan struct{}) 1018 serv.newConn = func( 1019 _ sendConn, 1020 _ connRunner, 1021 _ protocol.ConnectionID, 1022 _ *protocol.ConnectionID, 1023 _ protocol.ConnectionID, 1024 _ protocol.ConnectionID, 1025 _ protocol.ConnectionID, 1026 _ ConnectionIDGenerator, 1027 _ protocol.StatelessResetToken, 1028 conf *Config, 1029 _ *tls.Config, 1030 _ *handshake.TokenGenerator, 1031 _ bool, 1032 _ *logging.ConnectionTracer, 1033 _ uint64, 1034 _ utils.Logger, 1035 _ protocol.Version, 1036 ) quicConn { 1037 conn := NewMockQUICConn(mockCtrl) 1038 conn.EXPECT().handlePacket(gomock.Any()) 1039 conn.EXPECT().closeWithTransportError(ConnectionRefused).Do(func(TransportErrorCode) { close(destroyed) }) 1040 conn.EXPECT().HandshakeComplete().Return(make(chan struct{})) 1041 conn.EXPECT().run().MaxTimes(1) 1042 conn.EXPECT().Context().Return(context.Background()) 1043 return conn 1044 } 1045 phm.EXPECT().Get(gomock.Any()) 1046 phm.EXPECT().GetStatelessResetToken(gomock.Any()) 1047 phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).Return(true) 1048 serv.handleInitialImpl( 1049 receivedPacket{buffer: getPacketBuffer()}, 1050 &wire.Header{DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8})}, 1051 ) 1052 Eventually(destroyed).Should(BeClosed()) 1053 }) 1054 1055 It("returns when the context is canceled", func() { 1056 ctx, cancel := context.WithCancel(context.Background()) 1057 done := make(chan struct{}) 1058 go func() { 1059 defer GinkgoRecover() 1060 _, err := serv.Accept(ctx) 1061 Expect(err).To(MatchError("context canceled")) 1062 close(done) 1063 }() 1064 1065 Consistently(done).ShouldNot(BeClosed()) 1066 cancel() 1067 Eventually(done).Should(BeClosed()) 1068 }) 1069 1070 It("uses the config returned by GetConfigClient", func() { 1071 conn := NewMockQUICConn(mockCtrl) 1072 1073 conf := &Config{MaxIncomingStreams: 1234} 1074 serv.config = populateConfig(&Config{GetConfigForClient: func(*ClientHelloInfo) (*Config, error) { return conf, nil }}) 1075 done := make(chan struct{}) 1076 go func() { 1077 defer GinkgoRecover() 1078 s, err := serv.Accept(context.Background()) 1079 Expect(err).ToNot(HaveOccurred()) 1080 Expect(s).To(Equal(conn)) 1081 close(done) 1082 }() 1083 1084 handshakeChan := make(chan struct{}) 1085 serv.newConn = func( 1086 _ sendConn, 1087 _ connRunner, 1088 _ protocol.ConnectionID, 1089 _ *protocol.ConnectionID, 1090 _ protocol.ConnectionID, 1091 _ protocol.ConnectionID, 1092 _ protocol.ConnectionID, 1093 _ ConnectionIDGenerator, 1094 _ protocol.StatelessResetToken, 1095 conf *Config, 1096 _ *tls.Config, 1097 _ *handshake.TokenGenerator, 1098 _ bool, 1099 _ *logging.ConnectionTracer, 1100 _ uint64, 1101 _ utils.Logger, 1102 _ protocol.Version, 1103 ) quicConn { 1104 Expect(conf.MaxIncomingStreams).To(BeEquivalentTo(1234)) 1105 conn.EXPECT().handlePacket(gomock.Any()) 1106 conn.EXPECT().HandshakeComplete().Return(handshakeChan) 1107 conn.EXPECT().run() 1108 conn.EXPECT().Context().Return(context.Background()) 1109 return conn 1110 } 1111 phm.EXPECT().Get(gomock.Any()) 1112 phm.EXPECT().GetStatelessResetToken(gomock.Any()) 1113 phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).Return(true) 1114 serv.handleInitialImpl( 1115 receivedPacket{buffer: getPacketBuffer()}, 1116 &wire.Header{DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8})}, 1117 ) 1118 Consistently(done).ShouldNot(BeClosed()) 1119 close(handshakeChan) // complete the handshake 1120 Eventually(done).Should(BeClosed()) 1121 }) 1122 1123 It("rejects a connection attempt when GetConfigClient returns an error", func() { 1124 serv.config = populateConfig(&Config{GetConfigForClient: func(*ClientHelloInfo) (*Config, error) { return nil, errors.New("rejected") }}) 1125 1126 phm.EXPECT().Get(gomock.Any()) 1127 done := make(chan struct{}) 1128 tracer.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) 1129 conn.EXPECT().WriteTo(gomock.Any(), gomock.Any()).DoAndReturn(func(b []byte, _ net.Addr) (int, error) { 1130 defer close(done) 1131 rejectHdr := parseHeader(b) 1132 Expect(rejectHdr.Type).To(Equal(protocol.PacketTypeInitial)) 1133 return len(b), nil 1134 }) 1135 serv.handleInitialImpl( 1136 receivedPacket{buffer: getPacketBuffer()}, 1137 &wire.Header{DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}), Version: protocol.Version1}, 1138 ) 1139 Eventually(done).Should(BeClosed()) 1140 }) 1141 1142 It("accepts new connections when the handshake completes", func() { 1143 conn := NewMockQUICConn(mockCtrl) 1144 1145 done := make(chan struct{}) 1146 go func() { 1147 defer GinkgoRecover() 1148 s, err := serv.Accept(context.Background()) 1149 Expect(err).ToNot(HaveOccurred()) 1150 Expect(s).To(Equal(conn)) 1151 close(done) 1152 }() 1153 1154 handshakeChan := make(chan struct{}) 1155 serv.newConn = func( 1156 _ sendConn, 1157 runner connRunner, 1158 _ protocol.ConnectionID, 1159 _ *protocol.ConnectionID, 1160 _ protocol.ConnectionID, 1161 _ protocol.ConnectionID, 1162 _ protocol.ConnectionID, 1163 _ ConnectionIDGenerator, 1164 _ protocol.StatelessResetToken, 1165 _ *Config, 1166 _ *tls.Config, 1167 _ *handshake.TokenGenerator, 1168 _ bool, 1169 _ *logging.ConnectionTracer, 1170 _ uint64, 1171 _ utils.Logger, 1172 _ protocol.Version, 1173 ) quicConn { 1174 conn.EXPECT().handlePacket(gomock.Any()) 1175 conn.EXPECT().HandshakeComplete().Return(handshakeChan) 1176 conn.EXPECT().run() 1177 conn.EXPECT().Context().Return(context.Background()) 1178 return conn 1179 } 1180 phm.EXPECT().Get(gomock.Any()) 1181 phm.EXPECT().GetStatelessResetToken(gomock.Any()) 1182 phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).Return(true) 1183 serv.handleInitialImpl( 1184 receivedPacket{buffer: getPacketBuffer()}, 1185 &wire.Header{DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8})}, 1186 ) 1187 Consistently(done).ShouldNot(BeClosed()) 1188 close(handshakeChan) // complete the handshake 1189 Eventually(done).Should(BeClosed()) 1190 }) 1191 }) 1192 }) 1193 1194 Context("server accepting connections that haven't completed the handshake", func() { 1195 var ( 1196 serv *EarlyListener 1197 phm *MockPacketHandlerManager 1198 ) 1199 1200 BeforeEach(func() { 1201 var err error 1202 serv, err = ListenEarly(conn, tlsConf, nil) 1203 Expect(err).ToNot(HaveOccurred()) 1204 phm = NewMockPacketHandlerManager(mockCtrl) 1205 serv.baseServer.connHandler = phm 1206 }) 1207 1208 AfterEach(func() { 1209 serv.Close() 1210 }) 1211 1212 It("accepts new connections when they become ready", func() { 1213 conn := NewMockQUICConn(mockCtrl) 1214 1215 done := make(chan struct{}) 1216 go func() { 1217 defer GinkgoRecover() 1218 s, err := serv.Accept(context.Background()) 1219 Expect(err).ToNot(HaveOccurred()) 1220 Expect(s).To(Equal(conn)) 1221 close(done) 1222 }() 1223 1224 ready := make(chan struct{}) 1225 serv.baseServer.newConn = func( 1226 _ sendConn, 1227 runner connRunner, 1228 _ protocol.ConnectionID, 1229 _ *protocol.ConnectionID, 1230 _ protocol.ConnectionID, 1231 _ protocol.ConnectionID, 1232 _ protocol.ConnectionID, 1233 _ ConnectionIDGenerator, 1234 _ protocol.StatelessResetToken, 1235 _ *Config, 1236 _ *tls.Config, 1237 _ *handshake.TokenGenerator, 1238 _ bool, 1239 _ *logging.ConnectionTracer, 1240 _ uint64, 1241 _ utils.Logger, 1242 _ protocol.Version, 1243 ) quicConn { 1244 conn.EXPECT().handlePacket(gomock.Any()) 1245 conn.EXPECT().run() 1246 conn.EXPECT().earlyConnReady().Return(ready) 1247 conn.EXPECT().Context().Return(context.Background()) 1248 return conn 1249 } 1250 phm.EXPECT().Get(gomock.Any()) 1251 phm.EXPECT().GetStatelessResetToken(gomock.Any()) 1252 phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).Return(true) 1253 serv.baseServer.handleInitialImpl( 1254 receivedPacket{buffer: getPacketBuffer()}, 1255 &wire.Header{DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8})}, 1256 ) 1257 Consistently(done).ShouldNot(BeClosed()) 1258 close(ready) 1259 Eventually(done).Should(BeClosed()) 1260 }) 1261 1262 It("rejects new connection attempts if the accept queue is full", func() { 1263 connChan := make(chan *MockQUICConn, 1) 1264 var wg sync.WaitGroup // to make sure the test fully completes 1265 wg.Add(protocol.MaxAcceptQueueSize + 1) 1266 serv.baseServer.newConn = func( 1267 _ sendConn, 1268 runner connRunner, 1269 _ protocol.ConnectionID, 1270 _ *protocol.ConnectionID, 1271 _ protocol.ConnectionID, 1272 _ protocol.ConnectionID, 1273 _ protocol.ConnectionID, 1274 _ ConnectionIDGenerator, 1275 _ protocol.StatelessResetToken, 1276 _ *Config, 1277 _ *tls.Config, 1278 _ *handshake.TokenGenerator, 1279 _ bool, 1280 _ *logging.ConnectionTracer, 1281 _ uint64, 1282 _ utils.Logger, 1283 _ protocol.Version, 1284 ) quicConn { 1285 defer wg.Done() 1286 ready := make(chan struct{}) 1287 close(ready) 1288 conn := <-connChan 1289 conn.EXPECT().handlePacket(gomock.Any()) 1290 conn.EXPECT().run() 1291 conn.EXPECT().earlyConnReady().Return(ready) 1292 conn.EXPECT().Context().Return(context.Background()) 1293 return conn 1294 } 1295 1296 phm.EXPECT().Get(gomock.Any()).AnyTimes() 1297 phm.EXPECT().GetStatelessResetToken(gomock.Any()).Times(protocol.MaxAcceptQueueSize) 1298 phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).Return(true).Times(protocol.MaxAcceptQueueSize) 1299 for i := 0; i < protocol.MaxAcceptQueueSize; i++ { 1300 conn := NewMockQUICConn(mockCtrl) 1301 connChan <- conn 1302 serv.baseServer.handlePacket(getInitialWithRandomDestConnID()) 1303 } 1304 1305 Eventually(serv.baseServer.connQueue).Should(HaveLen(protocol.MaxAcceptQueueSize)) 1306 1307 phm.EXPECT().GetStatelessResetToken(gomock.Any()) 1308 phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).Return(true) 1309 conn := NewMockQUICConn(mockCtrl) 1310 conn.EXPECT().closeWithTransportError(ConnectionRefused) 1311 connChan <- conn 1312 serv.baseServer.handlePacket(getInitialWithRandomDestConnID()) 1313 wg.Wait() 1314 }) 1315 1316 It("doesn't accept new connections if they were closed in the mean time", func() { 1317 p := getInitial(protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10})) 1318 ctx, cancel := context.WithCancel(context.Background()) 1319 connCreated := make(chan struct{}) 1320 conn := NewMockQUICConn(mockCtrl) 1321 serv.baseServer.newConn = func( 1322 _ sendConn, 1323 runner connRunner, 1324 _ protocol.ConnectionID, 1325 _ *protocol.ConnectionID, 1326 _ protocol.ConnectionID, 1327 _ protocol.ConnectionID, 1328 _ protocol.ConnectionID, 1329 _ ConnectionIDGenerator, 1330 _ protocol.StatelessResetToken, 1331 _ *Config, 1332 _ *tls.Config, 1333 _ *handshake.TokenGenerator, 1334 _ bool, 1335 _ *logging.ConnectionTracer, 1336 _ uint64, 1337 _ utils.Logger, 1338 _ protocol.Version, 1339 ) quicConn { 1340 conn.EXPECT().handlePacket(p) 1341 conn.EXPECT().run() 1342 conn.EXPECT().earlyConnReady() 1343 conn.EXPECT().Context().Return(ctx) 1344 close(connCreated) 1345 return conn 1346 } 1347 1348 phm.EXPECT().Get(gomock.Any()) 1349 phm.EXPECT().GetStatelessResetToken(gomock.Any()) 1350 phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).Return(true) 1351 serv.baseServer.handlePacket(p) 1352 // make sure there are no Write calls on the packet conn 1353 time.Sleep(50 * time.Millisecond) 1354 Eventually(connCreated).Should(BeClosed()) 1355 cancel() 1356 time.Sleep(scaleDuration(200 * time.Millisecond)) 1357 1358 done := make(chan struct{}) 1359 go func() { 1360 defer GinkgoRecover() 1361 serv.Accept(context.Background()) 1362 close(done) 1363 }() 1364 Consistently(done).ShouldNot(BeClosed()) 1365 1366 // make the go routine return 1367 Expect(serv.Close()).To(Succeed()) 1368 Eventually(done).Should(BeClosed()) 1369 }) 1370 }) 1371 1372 Context("0-RTT", func() { 1373 var ( 1374 tr *Transport 1375 serv *baseServer 1376 phm *MockPacketHandlerManager 1377 tracer *mocklogging.MockTracer 1378 ) 1379 1380 BeforeEach(func() { 1381 var t *logging.Tracer 1382 t, tracer = mocklogging.NewMockTracer(mockCtrl) 1383 tr = &Transport{Conn: conn, Tracer: t} 1384 ln, err := tr.ListenEarly(tlsConf, nil) 1385 Expect(err).ToNot(HaveOccurred()) 1386 phm = NewMockPacketHandlerManager(mockCtrl) 1387 serv = ln.baseServer 1388 serv.connHandler = phm 1389 }) 1390 1391 AfterEach(func() { 1392 tracer.EXPECT().Close() 1393 Expect(tr.Close()).To(Succeed()) 1394 }) 1395 1396 It("passes packets to existing connections", func() { 1397 connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}) 1398 p := getPacket(&wire.Header{ 1399 Type: protocol.PacketType0RTT, 1400 DestConnectionID: connID, 1401 Version: serv.config.Versions[0], 1402 }, make([]byte, 100)) 1403 conn := NewMockPacketHandler(mockCtrl) 1404 phm.EXPECT().Get(connID).Return(conn, true) 1405 handled := make(chan struct{}) 1406 conn.EXPECT().handlePacket(p).Do(func(receivedPacket) { close(handled) }) 1407 serv.handlePacket(p) 1408 Eventually(handled).Should(BeClosed()) 1409 }) 1410 1411 It("queues 0-RTT packets, up to Max0RTTQueueSize", func() { 1412 connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}) 1413 1414 var zeroRTTPackets []receivedPacket 1415 1416 for i := 0; i < protocol.Max0RTTQueueLen; i++ { 1417 p := getPacket(&wire.Header{ 1418 Type: protocol.PacketType0RTT, 1419 DestConnectionID: connID, 1420 Version: serv.config.Versions[0], 1421 }, make([]byte, 100+i)) 1422 phm.EXPECT().Get(connID) 1423 serv.handlePacket(p) 1424 zeroRTTPackets = append(zeroRTTPackets, p) 1425 } 1426 1427 // send one more packet, this one should be dropped 1428 p := getPacket(&wire.Header{ 1429 Type: protocol.PacketType0RTT, 1430 DestConnectionID: connID, 1431 Version: serv.config.Versions[0], 1432 }, make([]byte, 200)) 1433 phm.EXPECT().Get(connID) 1434 tracer.EXPECT().DroppedPacket(p.remoteAddr, logging.PacketType0RTT, p.Size(), logging.PacketDropDOSPrevention) 1435 serv.handlePacket(p) 1436 1437 initial := getPacket(&wire.Header{ 1438 Type: protocol.PacketTypeInitial, 1439 DestConnectionID: connID, 1440 Version: serv.config.Versions[0], 1441 }, make([]byte, protocol.MinInitialPacketSize)) 1442 called := make(chan struct{}) 1443 serv.newConn = func( 1444 _ sendConn, 1445 _ connRunner, 1446 _ protocol.ConnectionID, 1447 _ *protocol.ConnectionID, 1448 _ protocol.ConnectionID, 1449 _ protocol.ConnectionID, 1450 _ protocol.ConnectionID, 1451 _ ConnectionIDGenerator, 1452 _ protocol.StatelessResetToken, 1453 _ *Config, 1454 _ *tls.Config, 1455 _ *handshake.TokenGenerator, 1456 _ bool, 1457 _ *logging.ConnectionTracer, 1458 _ uint64, 1459 _ utils.Logger, 1460 _ protocol.Version, 1461 ) quicConn { 1462 conn := NewMockQUICConn(mockCtrl) 1463 var calls []any 1464 calls = append(calls, conn.EXPECT().handlePacket(initial)) 1465 for _, p := range zeroRTTPackets { 1466 calls = append(calls, conn.EXPECT().handlePacket(p)) 1467 } 1468 gomock.InOrder(calls...) 1469 conn.EXPECT().run() 1470 conn.EXPECT().earlyConnReady() 1471 conn.EXPECT().Context().Return(context.Background()) 1472 close(called) 1473 // shutdown 1474 conn.EXPECT().closeWithTransportError(gomock.Any()) 1475 return conn 1476 } 1477 1478 phm.EXPECT().Get(connID) 1479 phm.EXPECT().GetStatelessResetToken(gomock.Any()) 1480 phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).Return(true) 1481 serv.handlePacket(initial) 1482 Eventually(called).Should(BeClosed()) 1483 }) 1484 1485 It("limits the number of queues", func() { 1486 for i := 0; i < protocol.Max0RTTQueues; i++ { 1487 b := make([]byte, 16) 1488 rand.Read(b) 1489 connID := protocol.ParseConnectionID(b) 1490 p := getPacket(&wire.Header{ 1491 Type: protocol.PacketType0RTT, 1492 DestConnectionID: connID, 1493 Version: serv.config.Versions[0], 1494 }, make([]byte, 100+i)) 1495 phm.EXPECT().Get(connID) 1496 serv.handlePacket(p) 1497 } 1498 1499 connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}) 1500 p := getPacket(&wire.Header{ 1501 Type: protocol.PacketType0RTT, 1502 DestConnectionID: connID, 1503 Version: serv.config.Versions[0], 1504 }, make([]byte, 200)) 1505 phm.EXPECT().Get(connID) 1506 dropped := make(chan struct{}) 1507 tracer.EXPECT().DroppedPacket(p.remoteAddr, logging.PacketType0RTT, p.Size(), logging.PacketDropDOSPrevention).Do(func(net.Addr, logging.PacketType, protocol.ByteCount, logging.PacketDropReason) { 1508 close(dropped) 1509 }) 1510 serv.handlePacket(p) 1511 Eventually(dropped).Should(BeClosed()) 1512 }) 1513 1514 It("drops queues after a while", func() { 1515 now := time.Now() 1516 1517 connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}) 1518 p := getPacket(&wire.Header{ 1519 Type: protocol.PacketType0RTT, 1520 DestConnectionID: connID, 1521 Version: serv.config.Versions[0], 1522 }, make([]byte, 200)) 1523 p.rcvTime = now 1524 1525 connID2 := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 9}) 1526 p2Time := now.Add(protocol.Max0RTTQueueingDuration / 2) 1527 p2 := getPacket(&wire.Header{ 1528 Type: protocol.PacketType0RTT, 1529 DestConnectionID: connID2, 1530 Version: serv.config.Versions[0], 1531 }, make([]byte, 300)) 1532 p2.rcvTime = p2Time // doesn't trigger the cleanup of the first packet 1533 1534 dropped1 := make(chan struct{}) 1535 dropped2 := make(chan struct{}) 1536 // need to register the call before handling the packet to avoid race condition 1537 gomock.InOrder( 1538 tracer.EXPECT().DroppedPacket(p.remoteAddr, logging.PacketType0RTT, p.Size(), logging.PacketDropDOSPrevention).Do(func(net.Addr, logging.PacketType, protocol.ByteCount, logging.PacketDropReason) { 1539 close(dropped1) 1540 }), 1541 tracer.EXPECT().DroppedPacket(p2.remoteAddr, logging.PacketType0RTT, p2.Size(), logging.PacketDropDOSPrevention).Do(func(net.Addr, logging.PacketType, protocol.ByteCount, logging.PacketDropReason) { 1542 close(dropped2) 1543 }), 1544 ) 1545 1546 phm.EXPECT().Get(connID) 1547 serv.handlePacket(p) 1548 1549 // There's no cleanup Go routine. 1550 // Cleanup is triggered when new packets are received. 1551 1552 phm.EXPECT().Get(connID2) 1553 serv.handlePacket(p2) 1554 // make sure no cleanup is executed 1555 Consistently(dropped1, 50*time.Millisecond).ShouldNot(BeClosed()) 1556 1557 // There's no cleanup Go routine. 1558 // Cleanup is triggered when new packets are received. 1559 connID3 := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 0}) 1560 p3 := getPacket(&wire.Header{ 1561 Type: protocol.PacketType0RTT, 1562 DestConnectionID: connID3, 1563 Version: serv.config.Versions[0], 1564 }, make([]byte, 200)) 1565 p3.rcvTime = now.Add(protocol.Max0RTTQueueingDuration + time.Nanosecond) // now triggers the cleanup 1566 phm.EXPECT().Get(connID3) 1567 serv.handlePacket(p3) 1568 Eventually(dropped1).Should(BeClosed()) 1569 Consistently(dropped2, 50*time.Millisecond).ShouldNot(BeClosed()) 1570 1571 // make sure the second packet is also cleaned up 1572 connID4 := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 1}) 1573 p4 := getPacket(&wire.Header{ 1574 Type: protocol.PacketType0RTT, 1575 DestConnectionID: connID4, 1576 Version: serv.config.Versions[0], 1577 }, make([]byte, 200)) 1578 p4.rcvTime = p2Time.Add(protocol.Max0RTTQueueingDuration + time.Nanosecond) // now triggers the cleanup 1579 phm.EXPECT().Get(connID4) 1580 serv.handlePacket(p4) 1581 Eventually(dropped2).Should(BeClosed()) 1582 }) 1583 }) 1584 })