github.com/TugasAkhir-QUIC/quic-go@v0.0.2-0.20240215011318-d20e25a9054c/server_test.go (about) 1 package quic 2 3 import ( 4 "context" 5 "crypto/rand" 6 "crypto/tls" 7 "errors" 8 "net" 9 "sync" 10 "sync/atomic" 11 "time" 12 13 "github.com/TugasAkhir-QUIC/quic-go/internal/handshake" 14 mocklogging "github.com/TugasAkhir-QUIC/quic-go/internal/mocks/logging" 15 "github.com/TugasAkhir-QUIC/quic-go/internal/protocol" 16 "github.com/TugasAkhir-QUIC/quic-go/internal/qerr" 17 "github.com/TugasAkhir-QUIC/quic-go/internal/testdata" 18 "github.com/TugasAkhir-QUIC/quic-go/internal/utils" 19 "github.com/TugasAkhir-QUIC/quic-go/internal/wire" 20 "github.com/TugasAkhir-QUIC/quic-go/logging" 21 22 . "github.com/onsi/ginkgo/v2" 23 . "github.com/onsi/gomega" 24 "go.uber.org/mock/gomock" 25 ) 26 27 var _ = Describe("Server", func() { 28 var ( 29 conn *MockPacketConn 30 tlsConf *tls.Config 31 ) 32 33 getPacket := func(hdr *wire.Header, p []byte) receivedPacket { 34 buf := getPacketBuffer() 35 hdr.Length = 4 + protocol.ByteCount(len(p)) + 16 36 var err error 37 buf.Data, err = (&wire.ExtendedHeader{ 38 Header: *hdr, 39 PacketNumber: 0x42, 40 PacketNumberLen: protocol.PacketNumberLen4, 41 }).Append(buf.Data, protocol.Version1) 42 Expect(err).ToNot(HaveOccurred()) 43 n := len(buf.Data) 44 buf.Data = append(buf.Data, p...) 45 data := buf.Data 46 sealer, _ := handshake.NewInitialAEAD(hdr.DestConnectionID, protocol.PerspectiveClient, hdr.Version) 47 _ = sealer.Seal(data[n:n], data[n:], 0x42, data[:n]) 48 data = data[:len(data)+16] 49 sealer.EncryptHeader(data[n:n+16], &data[0], data[n-4:n]) 50 return receivedPacket{ 51 remoteAddr: &net.UDPAddr{IP: net.IPv4(4, 5, 6, 7), Port: 456}, 52 data: data, 53 buffer: buf, 54 } 55 } 56 57 getInitial := func(destConnID protocol.ConnectionID) receivedPacket { 58 senderAddr := &net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 42} 59 hdr := &wire.Header{ 60 Type: protocol.PacketTypeInitial, 61 SrcConnectionID: protocol.ParseConnectionID([]byte{5, 4, 3, 2, 1}), 62 DestConnectionID: destConnID, 63 Version: protocol.Version1, 64 } 65 p := getPacket(hdr, make([]byte, protocol.MinInitialPacketSize)) 66 p.buffer = getPacketBuffer() 67 p.remoteAddr = senderAddr 68 return p 69 } 70 71 getInitialWithRandomDestConnID := func() receivedPacket { 72 b := make([]byte, 10) 73 _, err := rand.Read(b) 74 Expect(err).ToNot(HaveOccurred()) 75 76 return getInitial(protocol.ParseConnectionID(b)) 77 } 78 79 parseHeader := func(data []byte) *wire.Header { 80 hdr, _, _, err := wire.ParsePacket(data) 81 Expect(err).ToNot(HaveOccurred()) 82 return hdr 83 } 84 85 checkConnectionCloseError := func(b []byte, origHdr *wire.Header, errorCode qerr.TransportErrorCode) { 86 replyHdr := parseHeader(b) 87 Expect(replyHdr.Type).To(Equal(protocol.PacketTypeInitial)) 88 Expect(replyHdr.SrcConnectionID).To(Equal(origHdr.DestConnectionID)) 89 Expect(replyHdr.DestConnectionID).To(Equal(origHdr.SrcConnectionID)) 90 _, opener := handshake.NewInitialAEAD(origHdr.DestConnectionID, protocol.PerspectiveClient, replyHdr.Version) 91 extHdr, err := unpackLongHeader(opener, replyHdr, b, origHdr.Version) 92 Expect(err).ToNot(HaveOccurred()) 93 data, err := opener.Open(nil, b[extHdr.ParsedLen():], extHdr.PacketNumber, b[:extHdr.ParsedLen()]) 94 Expect(err).ToNot(HaveOccurred()) 95 _, f, err := wire.NewFrameParser(false).ParseNext(data, protocol.EncryptionInitial, origHdr.Version) 96 Expect(err).ToNot(HaveOccurred()) 97 Expect(f).To(BeAssignableToTypeOf(&wire.ConnectionCloseFrame{})) 98 ccf := f.(*wire.ConnectionCloseFrame) 99 Expect(ccf.IsApplicationError).To(BeFalse()) 100 Expect(ccf.ErrorCode).To(BeEquivalentTo(errorCode)) 101 Expect(ccf.ReasonPhrase).To(BeEmpty()) 102 } 103 104 BeforeEach(func() { 105 conn = NewMockPacketConn(mockCtrl) 106 conn.EXPECT().LocalAddr().Return(&net.UDPAddr{}).AnyTimes() 107 wait := make(chan struct{}) 108 conn.EXPECT().ReadFrom(gomock.Any()).DoAndReturn(func(_ []byte) (int, net.Addr, error) { 109 <-wait 110 return 0, nil, errors.New("done") 111 }).MaxTimes(1) 112 conn.EXPECT().SetReadDeadline(gomock.Any()).Do(func(time.Time) error { 113 close(wait) 114 conn.EXPECT().SetReadDeadline(time.Time{}) 115 return nil 116 }).MaxTimes(1) 117 tlsConf = testdata.GetTLSConfig() 118 tlsConf.NextProtos = []string{"proto1"} 119 }) 120 121 It("errors when no tls.Config is given", func() { 122 _, err := ListenAddr("localhost:0", nil, nil) 123 Expect(err).To(HaveOccurred()) 124 Expect(err.Error()).To(ContainSubstring("quic: tls.Config not set")) 125 }) 126 127 It("errors when the Config contains an invalid version", func() { 128 version := protocol.Version(0x1234) 129 _, err := Listen(nil, tlsConf, &Config{Versions: []protocol.Version{version}}) 130 Expect(err).To(MatchError("invalid QUIC version: 0x1234")) 131 }) 132 133 It("fills in default values if options are not set in the Config", func() { 134 ln, err := Listen(conn, tlsConf, &Config{}) 135 Expect(err).ToNot(HaveOccurred()) 136 server := ln.baseServer 137 Expect(server.config.Versions).To(Equal(protocol.SupportedVersions)) 138 Expect(server.config.HandshakeIdleTimeout).To(Equal(protocol.DefaultHandshakeIdleTimeout)) 139 Expect(server.config.MaxIdleTimeout).To(Equal(protocol.DefaultIdleTimeout)) 140 Expect(server.config.KeepAlivePeriod).To(BeZero()) 141 // stop the listener 142 Expect(ln.Close()).To(Succeed()) 143 }) 144 145 It("setups with the right values", func() { 146 supportedVersions := []protocol.Version{protocol.Version1} 147 config := Config{ 148 Versions: supportedVersions, 149 HandshakeIdleTimeout: 1337 * time.Hour, 150 MaxIdleTimeout: 42 * time.Minute, 151 KeepAlivePeriod: 5 * time.Second, 152 } 153 ln, err := Listen(conn, tlsConf, &config) 154 Expect(err).ToNot(HaveOccurred()) 155 server := ln.baseServer 156 Expect(server.connHandler).ToNot(BeNil()) 157 Expect(server.config.Versions).To(Equal(supportedVersions)) 158 Expect(server.config.HandshakeIdleTimeout).To(Equal(1337 * time.Hour)) 159 Expect(server.config.MaxIdleTimeout).To(Equal(42 * time.Minute)) 160 Expect(server.config.KeepAlivePeriod).To(Equal(5 * time.Second)) 161 // stop the listener 162 Expect(ln.Close()).To(Succeed()) 163 }) 164 165 It("listens on a given address", func() { 166 addr := "127.0.0.1:13579" 167 ln, err := ListenAddr(addr, tlsConf, &Config{}) 168 Expect(err).ToNot(HaveOccurred()) 169 Expect(ln.Addr().String()).To(Equal(addr)) 170 // stop the listener 171 Expect(ln.Close()).To(Succeed()) 172 }) 173 174 It("errors if given an invalid address", func() { 175 addr := "127.0.0.1" 176 _, err := ListenAddr(addr, tlsConf, &Config{}) 177 Expect(err).To(BeAssignableToTypeOf(&net.AddrError{})) 178 }) 179 180 It("errors if given an invalid address", func() { 181 addr := "1.1.1.1:1111" 182 _, err := ListenAddr(addr, tlsConf, &Config{}) 183 Expect(err).To(BeAssignableToTypeOf(&net.OpError{})) 184 }) 185 186 Context("server accepting connections that completed the handshake", func() { 187 var ( 188 tr *Transport 189 serv *baseServer 190 phm *MockPacketHandlerManager 191 tracer *mocklogging.MockTracer 192 ) 193 194 BeforeEach(func() { 195 var t *logging.Tracer 196 t, tracer = mocklogging.NewMockTracer(mockCtrl) 197 tr = &Transport{Conn: conn, Tracer: t} 198 ln, err := tr.Listen(tlsConf, nil) 199 Expect(err).ToNot(HaveOccurred()) 200 serv = ln.baseServer 201 phm = NewMockPacketHandlerManager(mockCtrl) 202 serv.connHandler = phm 203 }) 204 205 AfterEach(func() { 206 tracer.EXPECT().Close() 207 tr.Close() 208 }) 209 210 Context("handling packets", func() { 211 It("drops Initial packets with a too short connection ID", func() { 212 p := getPacket(&wire.Header{ 213 Type: protocol.PacketTypeInitial, 214 DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4}), 215 Version: serv.config.Versions[0], 216 }, nil) 217 tracer.EXPECT().DroppedPacket(p.remoteAddr, logging.PacketTypeInitial, p.Size(), logging.PacketDropUnexpectedPacket) 218 serv.handlePacket(p) 219 // make sure there are no Write calls on the packet conn 220 time.Sleep(50 * time.Millisecond) 221 }) 222 223 It("drops too small Initial", func() { 224 p := getPacket(&wire.Header{ 225 Type: protocol.PacketTypeInitial, 226 DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}), 227 Version: serv.config.Versions[0], 228 }, make([]byte, protocol.MinInitialPacketSize-100)) 229 tracer.EXPECT().DroppedPacket(p.remoteAddr, logging.PacketTypeInitial, p.Size(), logging.PacketDropUnexpectedPacket) 230 serv.handlePacket(p) 231 // make sure there are no Write calls on the packet conn 232 time.Sleep(50 * time.Millisecond) 233 }) 234 235 It("drops non-Initial packets", func() { 236 p := getPacket(&wire.Header{ 237 Type: protocol.PacketTypeHandshake, 238 Version: serv.config.Versions[0], 239 }, []byte("invalid")) 240 tracer.EXPECT().DroppedPacket(p.remoteAddr, logging.PacketTypeHandshake, p.Size(), logging.PacketDropUnexpectedPacket) 241 serv.handlePacket(p) 242 // make sure there are no Write calls on the packet conn 243 time.Sleep(50 * time.Millisecond) 244 }) 245 246 It("passes packets to existing connections", func() { 247 connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}) 248 p := getPacket(&wire.Header{ 249 Type: protocol.PacketTypeInitial, 250 DestConnectionID: connID, 251 Version: serv.config.Versions[0], 252 }, make([]byte, protocol.MinInitialPacketSize)) 253 conn := NewMockPacketHandler(mockCtrl) 254 phm.EXPECT().Get(connID).Return(conn, true) 255 handled := make(chan struct{}) 256 conn.EXPECT().handlePacket(p).Do(func(receivedPacket) { close(handled) }) 257 serv.handlePacket(p) 258 Eventually(handled).Should(BeClosed()) 259 }) 260 261 It("creates a connection when the token is accepted", func() { 262 serv.maxNumHandshakesUnvalidated = 0 263 raddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337} 264 retryToken, err := serv.tokenGenerator.NewRetryToken( 265 raddr, 266 protocol.ParseConnectionID([]byte{0xde, 0xad, 0xc0, 0xde}), 267 protocol.ParseConnectionID([]byte{0xde, 0xca, 0xfb, 0xad}), 268 ) 269 Expect(err).ToNot(HaveOccurred()) 270 connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}) 271 hdr := &wire.Header{ 272 Type: protocol.PacketTypeInitial, 273 SrcConnectionID: protocol.ParseConnectionID([]byte{5, 4, 3, 2, 1}), 274 DestConnectionID: connID, 275 Version: protocol.Version1, 276 Token: retryToken, 277 } 278 p := getPacket(hdr, make([]byte, protocol.MinInitialPacketSize)) 279 p.remoteAddr = raddr 280 run := make(chan struct{}) 281 var token protocol.StatelessResetToken 282 rand.Read(token[:]) 283 284 var newConnID protocol.ConnectionID 285 conn := NewMockQUICConn(mockCtrl) 286 serv.newConn = func( 287 _ sendConn, 288 _ connRunner, 289 origDestConnID protocol.ConnectionID, 290 retrySrcConnID *protocol.ConnectionID, 291 clientDestConnID protocol.ConnectionID, 292 destConnID protocol.ConnectionID, 293 srcConnID protocol.ConnectionID, 294 _ ConnectionIDGenerator, 295 tokenP protocol.StatelessResetToken, 296 _ *Config, 297 _ *tls.Config, 298 _ *handshake.TokenGenerator, 299 _ bool, 300 _ *logging.ConnectionTracer, 301 _ uint64, 302 _ utils.Logger, 303 _ protocol.Version, 304 ) quicConn { 305 Expect(origDestConnID).To(Equal(protocol.ParseConnectionID([]byte{0xde, 0xad, 0xc0, 0xde}))) 306 Expect(*retrySrcConnID).To(Equal(protocol.ParseConnectionID([]byte{0xde, 0xca, 0xfb, 0xad}))) 307 Expect(clientDestConnID).To(Equal(hdr.DestConnectionID)) 308 Expect(destConnID).To(Equal(hdr.SrcConnectionID)) 309 // make sure we're using a server-generated connection ID 310 Expect(srcConnID).ToNot(Equal(hdr.DestConnectionID)) 311 Expect(srcConnID).ToNot(Equal(hdr.SrcConnectionID)) 312 newConnID = srcConnID 313 Expect(tokenP).To(Equal(token)) 314 conn.EXPECT().handlePacket(p) 315 conn.EXPECT().run().Do(func() error { close(run); return nil }) 316 conn.EXPECT().Context().Return(context.Background()) 317 conn.EXPECT().HandshakeComplete().Return(make(chan struct{})) 318 return conn 319 } 320 phm.EXPECT().Get(connID) 321 phm.EXPECT().GetStatelessResetToken(gomock.Any()).Return(token) 322 phm.EXPECT().AddWithConnID(connID, gomock.Any(), gomock.Any()).DoAndReturn(func(_, cid protocol.ConnectionID, h packetHandler) bool { 323 Expect(cid).To(Equal(newConnID)) 324 return true 325 }) 326 327 done := make(chan struct{}) 328 go func() { 329 defer GinkgoRecover() 330 serv.handlePacket(p) 331 // the Handshake packet is written by the connection. 332 // Make sure there are no Write calls on the packet conn. 333 time.Sleep(50 * time.Millisecond) 334 close(done) 335 }() 336 // make sure we're using a server-generated connection ID 337 Eventually(run).Should(BeClosed()) 338 Eventually(done).Should(BeClosed()) 339 // shutdown 340 conn.EXPECT().closeWithTransportError(gomock.Any()) 341 }) 342 343 It("sends a Version Negotiation Packet for unsupported versions", func() { 344 srcConnID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5}) 345 destConnID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6}) 346 packet := getPacket(&wire.Header{ 347 Type: protocol.PacketTypeHandshake, 348 SrcConnectionID: srcConnID, 349 DestConnectionID: destConnID, 350 Version: 0x42, 351 }, make([]byte, protocol.MinUnknownVersionPacketSize)) 352 raddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337} 353 packet.remoteAddr = raddr 354 tracer.EXPECT().SentVersionNegotiationPacket(packet.remoteAddr, gomock.Any(), gomock.Any(), gomock.Any()).Do(func(_ net.Addr, src, dest protocol.ArbitraryLenConnectionID, _ []protocol.Version) { 355 Expect(src).To(Equal(protocol.ArbitraryLenConnectionID(destConnID.Bytes()))) 356 Expect(dest).To(Equal(protocol.ArbitraryLenConnectionID(srcConnID.Bytes()))) 357 }) 358 done := make(chan struct{}) 359 conn.EXPECT().WriteTo(gomock.Any(), raddr).DoAndReturn(func(b []byte, _ net.Addr) (int, error) { 360 defer close(done) 361 Expect(wire.IsVersionNegotiationPacket(b)).To(BeTrue()) 362 dest, src, versions, err := wire.ParseVersionNegotiationPacket(b) 363 Expect(err).ToNot(HaveOccurred()) 364 Expect(dest).To(Equal(protocol.ArbitraryLenConnectionID(srcConnID.Bytes()))) 365 Expect(src).To(Equal(protocol.ArbitraryLenConnectionID(destConnID.Bytes()))) 366 Expect(versions).ToNot(ContainElement(protocol.Version(0x42))) 367 return len(b), nil 368 }) 369 serv.handlePacket(packet) 370 Eventually(done).Should(BeClosed()) 371 }) 372 373 It("doesn't send a Version Negotiation packets if sending them is disabled", func() { 374 serv.disableVersionNegotiation = true 375 srcConnID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5}) 376 destConnID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6}) 377 packet := getPacket(&wire.Header{ 378 Type: protocol.PacketTypeHandshake, 379 SrcConnectionID: srcConnID, 380 DestConnectionID: destConnID, 381 Version: 0x42, 382 }, make([]byte, protocol.MinUnknownVersionPacketSize)) 383 raddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337} 384 packet.remoteAddr = raddr 385 done := make(chan struct{}) 386 serv.handlePacket(packet) 387 Consistently(done, 50*time.Millisecond).ShouldNot(BeClosed()) 388 }) 389 390 It("ignores Version Negotiation packets", func() { 391 data := wire.ComposeVersionNegotiation( 392 protocol.ArbitraryLenConnectionID{1, 2, 3, 4}, 393 protocol.ArbitraryLenConnectionID{4, 3, 2, 1}, 394 []protocol.Version{1, 2, 3}, 395 ) 396 raddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337} 397 done := make(chan struct{}) 398 tracer.EXPECT().DroppedPacket(raddr, logging.PacketTypeVersionNegotiation, protocol.ByteCount(len(data)), logging.PacketDropUnexpectedPacket).Do(func(net.Addr, logging.PacketType, protocol.ByteCount, logging.PacketDropReason) { 399 close(done) 400 }) 401 serv.handlePacket(receivedPacket{ 402 remoteAddr: raddr, 403 data: data, 404 buffer: getPacketBuffer(), 405 }) 406 Eventually(done).Should(BeClosed()) 407 // make sure no other packet is sent 408 time.Sleep(scaleDuration(20 * time.Millisecond)) 409 }) 410 411 It("doesn't send a Version Negotiation Packet for unsupported versions, if the packet is too small", func() { 412 srcConnID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5}) 413 destConnID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6}) 414 p := getPacket(&wire.Header{ 415 Type: protocol.PacketTypeHandshake, 416 SrcConnectionID: srcConnID, 417 DestConnectionID: destConnID, 418 Version: 0x42, 419 }, make([]byte, protocol.MinUnknownVersionPacketSize-50)) 420 Expect(p.Size()).To(BeNumerically("<", protocol.MinUnknownVersionPacketSize)) 421 raddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337} 422 p.remoteAddr = raddr 423 done := make(chan struct{}) 424 tracer.EXPECT().DroppedPacket(raddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropUnexpectedPacket).Do(func(net.Addr, logging.PacketType, protocol.ByteCount, logging.PacketDropReason) { 425 close(done) 426 }) 427 serv.handlePacket(p) 428 Eventually(done).Should(BeClosed()) 429 // make sure no other packet is sent 430 time.Sleep(scaleDuration(20 * time.Millisecond)) 431 }) 432 433 It("replies with a Retry packet, if a token is required", func() { 434 connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}) 435 serv.maxNumHandshakesUnvalidated = 0 436 hdr := &wire.Header{ 437 Type: protocol.PacketTypeInitial, 438 SrcConnectionID: protocol.ParseConnectionID([]byte{5, 4, 3, 2, 1}), 439 DestConnectionID: connID, 440 Version: protocol.Version1, 441 } 442 packet := getPacket(hdr, make([]byte, protocol.MinInitialPacketSize)) 443 raddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337} 444 packet.remoteAddr = raddr 445 tracer.EXPECT().SentPacket(packet.remoteAddr, gomock.Any(), gomock.Any(), nil).Do(func(_ net.Addr, replyHdr *logging.Header, _ logging.ByteCount, _ []logging.Frame) { 446 Expect(replyHdr.Type).To(Equal(protocol.PacketTypeRetry)) 447 Expect(replyHdr.SrcConnectionID).ToNot(Equal(hdr.DestConnectionID)) 448 Expect(replyHdr.DestConnectionID).To(Equal(hdr.SrcConnectionID)) 449 Expect(replyHdr.Token).ToNot(BeEmpty()) 450 }) 451 done := make(chan struct{}) 452 conn.EXPECT().WriteTo(gomock.Any(), raddr).DoAndReturn(func(b []byte, _ net.Addr) (int, error) { 453 defer close(done) 454 replyHdr := parseHeader(b) 455 Expect(replyHdr.Type).To(Equal(protocol.PacketTypeRetry)) 456 Expect(replyHdr.SrcConnectionID).ToNot(Equal(hdr.DestConnectionID)) 457 Expect(replyHdr.DestConnectionID).To(Equal(hdr.SrcConnectionID)) 458 Expect(replyHdr.Token).ToNot(BeEmpty()) 459 Expect(b[len(b)-16:]).To(Equal(handshake.GetRetryIntegrityTag(b[:len(b)-16], hdr.DestConnectionID, hdr.Version)[:])) 460 return len(b), nil 461 }) 462 phm.EXPECT().Get(connID) 463 serv.handlePacket(packet) 464 Eventually(done).Should(BeClosed()) 465 }) 466 467 It("creates a connection, if no token is required", func() { 468 connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}) 469 hdr := &wire.Header{ 470 Type: protocol.PacketTypeInitial, 471 SrcConnectionID: protocol.ParseConnectionID([]byte{5, 4, 3, 2, 1}), 472 DestConnectionID: connID, 473 Version: protocol.Version1, 474 } 475 p := getPacket(hdr, make([]byte, protocol.MinInitialPacketSize)) 476 run := make(chan struct{}) 477 var token protocol.StatelessResetToken 478 rand.Read(token[:]) 479 480 var newConnID protocol.ConnectionID 481 conn := NewMockQUICConn(mockCtrl) 482 serv.newConn = func( 483 _ sendConn, 484 _ connRunner, 485 origDestConnID protocol.ConnectionID, 486 retrySrcConnID *protocol.ConnectionID, 487 clientDestConnID protocol.ConnectionID, 488 destConnID protocol.ConnectionID, 489 srcConnID protocol.ConnectionID, 490 _ ConnectionIDGenerator, 491 tokenP protocol.StatelessResetToken, 492 _ *Config, 493 _ *tls.Config, 494 _ *handshake.TokenGenerator, 495 _ bool, 496 _ *logging.ConnectionTracer, 497 _ uint64, 498 _ utils.Logger, 499 _ protocol.Version, 500 ) quicConn { 501 Expect(origDestConnID).To(Equal(hdr.DestConnectionID)) 502 Expect(retrySrcConnID).To(BeNil()) 503 Expect(clientDestConnID).To(Equal(hdr.DestConnectionID)) 504 Expect(destConnID).To(Equal(hdr.SrcConnectionID)) 505 // make sure we're using a server-generated connection ID 506 Expect(srcConnID).ToNot(Equal(hdr.DestConnectionID)) 507 Expect(srcConnID).ToNot(Equal(hdr.SrcConnectionID)) 508 newConnID = srcConnID 509 Expect(tokenP).To(Equal(token)) 510 conn.EXPECT().handlePacket(p) 511 conn.EXPECT().run().Do(func() error { close(run); return nil }) 512 conn.EXPECT().Context().Return(context.Background()) 513 conn.EXPECT().HandshakeComplete().Return(make(chan struct{})) 514 return conn 515 } 516 gomock.InOrder( 517 phm.EXPECT().Get(connID), 518 phm.EXPECT().GetStatelessResetToken(gomock.Any()).Return(token), 519 phm.EXPECT().AddWithConnID(connID, gomock.Any(), gomock.Any()).DoAndReturn(func(_, c protocol.ConnectionID, h packetHandler) bool { 520 Expect(c).To(Equal(newConnID)) 521 return true 522 }), 523 ) 524 525 done := make(chan struct{}) 526 go func() { 527 defer GinkgoRecover() 528 serv.handlePacket(p) 529 // the Handshake packet is written by the connection 530 // make sure there are no Write calls on the packet conn 531 time.Sleep(50 * time.Millisecond) 532 close(done) 533 }() 534 // make sure we're using a server-generated connection ID 535 Eventually(run).Should(BeClosed()) 536 Eventually(done).Should(BeClosed()) 537 // shutdown 538 conn.EXPECT().closeWithTransportError(gomock.Any()).MaxTimes(1) 539 }) 540 541 It("drops packets if the receive queue is full", func() { 542 serv.maxNumHandshakesTotal = 10000 543 serv.maxNumHandshakesUnvalidated = 10000 544 545 phm.EXPECT().Get(gomock.Any()).AnyTimes() 546 phm.EXPECT().GetStatelessResetToken(gomock.Any()).AnyTimes() 547 phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).Return(true).AnyTimes() 548 549 acceptConn := make(chan struct{}) 550 var counter atomic.Uint32 551 serv.newConn = func( 552 _ sendConn, 553 runner connRunner, 554 _ protocol.ConnectionID, 555 _ *protocol.ConnectionID, 556 _ protocol.ConnectionID, 557 _ protocol.ConnectionID, 558 _ protocol.ConnectionID, 559 _ ConnectionIDGenerator, 560 _ protocol.StatelessResetToken, 561 _ *Config, 562 _ *tls.Config, 563 _ *handshake.TokenGenerator, 564 _ bool, 565 _ *logging.ConnectionTracer, 566 _ uint64, 567 _ utils.Logger, 568 _ protocol.Version, 569 ) quicConn { 570 <-acceptConn 571 counter.Add(1) 572 conn := NewMockQUICConn(mockCtrl) 573 conn.EXPECT().handlePacket(gomock.Any()).MaxTimes(1) 574 conn.EXPECT().run().MaxTimes(1) 575 conn.EXPECT().Context().Return(context.Background()).MaxTimes(1) 576 conn.EXPECT().HandshakeComplete().Return(make(chan struct{})).MaxTimes(1) 577 // shutdown 578 conn.EXPECT().closeWithTransportError(gomock.Any()).MaxTimes(1) 579 return conn 580 } 581 582 p := getInitial(protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8})) 583 serv.handlePacket(p) 584 tracer.EXPECT().DroppedPacket(p.remoteAddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropDOSPrevention).MinTimes(1) 585 var wg sync.WaitGroup 586 for i := 0; i < 3*protocol.MaxServerUnprocessedPackets; i++ { 587 wg.Add(1) 588 go func() { 589 defer GinkgoRecover() 590 defer wg.Done() 591 serv.handlePacket(getInitial(protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}))) 592 }() 593 } 594 wg.Wait() 595 596 close(acceptConn) 597 Eventually( 598 func() uint32 { return counter.Load() }, 599 scaleDuration(100*time.Millisecond), 600 ).Should(BeEquivalentTo(protocol.MaxServerUnprocessedPackets + 1)) 601 Consistently(func() uint32 { return counter.Load() }).Should(BeEquivalentTo(protocol.MaxServerUnprocessedPackets + 1)) 602 }) 603 604 PIt("only creates a single connection for a duplicate Initial", func() { 605 var createdConn bool 606 serv.newConn = func( 607 _ sendConn, 608 runner connRunner, 609 _ protocol.ConnectionID, 610 _ *protocol.ConnectionID, 611 _ protocol.ConnectionID, 612 _ protocol.ConnectionID, 613 _ protocol.ConnectionID, 614 _ ConnectionIDGenerator, 615 _ protocol.StatelessResetToken, 616 _ *Config, 617 _ *tls.Config, 618 _ *handshake.TokenGenerator, 619 _ bool, 620 _ *logging.ConnectionTracer, 621 _ uint64, 622 _ utils.Logger, 623 _ protocol.Version, 624 ) quicConn { 625 createdConn = true 626 return NewMockQUICConn(mockCtrl) 627 } 628 629 connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9}) 630 p := getInitial(connID) 631 phm.EXPECT().Get(connID) 632 phm.EXPECT().AddWithConnID(connID, gomock.Any(), gomock.Any()).Return(false) // connection ID collision 633 tracer.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) 634 done := make(chan struct{}) 635 conn.EXPECT().WriteTo(gomock.Any(), gomock.Any()).Do(func([]byte, net.Addr) (int, error) { close(done); return 0, nil }) 636 Expect(serv.handlePacketImpl(p)).To(BeTrue()) 637 Expect(createdConn).To(BeFalse()) 638 Eventually(done).Should(BeClosed()) 639 }) 640 641 It("limits the number of unvalidated handshakes", func() { 642 const limit = 3 643 serv.maxNumHandshakesTotal = 10000 644 serv.maxNumHandshakesUnvalidated = limit 645 646 phm.EXPECT().Get(gomock.Any()).AnyTimes() 647 phm.EXPECT().GetStatelessResetToken(gomock.Any()).AnyTimes() 648 phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).Return(true).AnyTimes() 649 650 handshakeChan := make(chan struct{}) 651 connChan := make(chan *MockQUICConn, 1) 652 var wg sync.WaitGroup 653 wg.Add(2 * limit) 654 serv.newConn = func( 655 _ sendConn, 656 runner connRunner, 657 _ protocol.ConnectionID, 658 _ *protocol.ConnectionID, 659 _ protocol.ConnectionID, 660 _ protocol.ConnectionID, 661 _ protocol.ConnectionID, 662 _ ConnectionIDGenerator, 663 _ protocol.StatelessResetToken, 664 _ *Config, 665 _ *tls.Config, 666 _ *handshake.TokenGenerator, 667 _ bool, 668 _ *logging.ConnectionTracer, 669 _ uint64, 670 _ utils.Logger, 671 _ protocol.Version, 672 ) quicConn { 673 conn := <-connChan 674 conn.EXPECT().handlePacket(gomock.Any()) 675 conn.EXPECT().run() 676 conn.EXPECT().Context().Return(context.Background()) 677 conn.EXPECT().HandshakeComplete().Return(handshakeChan).Do(func() <-chan struct{} { wg.Done(); return nil }) 678 return conn 679 } 680 681 // Initiate the maximum number of allowed connection attempts. 682 for i := 0; i < limit; i++ { 683 conn := NewMockQUICConn(mockCtrl) 684 connChan <- conn 685 serv.handlePacket(getInitialWithRandomDestConnID()) 686 } 687 688 // Now initiate another connection attempt. 689 p := getInitialWithRandomDestConnID() 690 done := make(chan struct{}) 691 tracer.EXPECT().SentPacket(p.remoteAddr, gomock.Any(), gomock.Any(), gomock.Any()).Do(func(_ net.Addr, replyHdr *logging.Header, _ logging.ByteCount, frames []logging.Frame) { 692 defer GinkgoRecover() 693 Expect(replyHdr.Type).To(Equal(protocol.PacketTypeRetry)) 694 }) 695 conn.EXPECT().WriteTo(gomock.Any(), gomock.Any()).DoAndReturn(func(b []byte, _ net.Addr) (int, error) { 696 defer GinkgoRecover() 697 defer close(done) 698 hdr, _, _, err := wire.ParsePacket(b) 699 Expect(err).ToNot(HaveOccurred()) 700 Expect(hdr.Type).To(Equal(protocol.PacketTypeRetry)) 701 return len(b), nil 702 }) 703 serv.handlePacket(p) 704 Eventually(done).Should(BeClosed()) 705 706 close(handshakeChan) 707 for i := 0; i < limit; i++ { 708 _, err := serv.Accept(context.Background()) 709 Expect(err).ToNot(HaveOccurred()) 710 } 711 for i := 0; i < limit; i++ { 712 conn := NewMockQUICConn(mockCtrl) 713 conn.EXPECT().closeWithTransportError(gomock.Any()).MaxTimes(1) // called when the server is closed 714 connChan <- conn 715 serv.handlePacket(getInitialWithRandomDestConnID()) 716 } 717 wg.Wait() 718 }) 719 720 It("limits the number of total handshakes", func() { 721 const limit = 3 722 serv.maxNumHandshakesTotal = limit 723 serv.maxNumHandshakesUnvalidated = limit // same limit, but we check that we send CONNECTION_REFUSED and not Retry 724 725 phm.EXPECT().Get(gomock.Any()).AnyTimes() 726 phm.EXPECT().GetStatelessResetToken(gomock.Any()).AnyTimes() 727 phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).Return(true).AnyTimes() 728 729 handshakeChan := make(chan struct{}) 730 connChan := make(chan *MockQUICConn, 1) 731 serv.newConn = func( 732 _ sendConn, 733 runner connRunner, 734 _ protocol.ConnectionID, 735 _ *protocol.ConnectionID, 736 _ protocol.ConnectionID, 737 _ protocol.ConnectionID, 738 _ protocol.ConnectionID, 739 _ ConnectionIDGenerator, 740 _ protocol.StatelessResetToken, 741 _ *Config, 742 _ *tls.Config, 743 _ *handshake.TokenGenerator, 744 _ bool, 745 _ *logging.ConnectionTracer, 746 _ uint64, 747 _ utils.Logger, 748 _ protocol.Version, 749 ) quicConn { 750 conn := <-connChan 751 conn.EXPECT().handlePacket(gomock.Any()) 752 conn.EXPECT().run() 753 conn.EXPECT().Context().Return(context.Background()) 754 conn.EXPECT().HandshakeComplete().Return(handshakeChan) 755 return conn 756 } 757 758 for i := 0; i < limit; i++ { 759 conn := NewMockQUICConn(mockCtrl) 760 connChan <- conn 761 serv.handlePacket(getInitialWithRandomDestConnID()) 762 } 763 764 p := getInitialWithRandomDestConnID() 765 done := make(chan struct{}) 766 tracer.EXPECT().SentPacket(p.remoteAddr, gomock.Any(), gomock.Any(), gomock.Any()).Do(func(_ net.Addr, replyHdr *logging.Header, _ logging.ByteCount, frames []logging.Frame) { 767 defer GinkgoRecover() 768 hdr, _, _, err := wire.ParsePacket(p.data) 769 Expect(err).ToNot(HaveOccurred()) 770 Expect(replyHdr.Type).To(Equal(protocol.PacketTypeInitial)) 771 Expect(replyHdr.SrcConnectionID).To(Equal(hdr.DestConnectionID)) 772 Expect(replyHdr.DestConnectionID).To(Equal(hdr.SrcConnectionID)) 773 Expect(frames).To(HaveLen(1)) 774 Expect(frames[0]).To(BeAssignableToTypeOf(&wire.ConnectionCloseFrame{})) 775 ccf := frames[0].(*logging.ConnectionCloseFrame) 776 Expect(ccf.IsApplicationError).To(BeFalse()) 777 Expect(ccf.ErrorCode).To(BeEquivalentTo(qerr.ConnectionRefused)) 778 }) 779 conn.EXPECT().WriteTo(gomock.Any(), gomock.Any()).DoAndReturn(func(b []byte, _ net.Addr) (int, error) { 780 defer GinkgoRecover() 781 defer close(done) 782 hdr, _, _, err := wire.ParsePacket(p.data) 783 Expect(err).ToNot(HaveOccurred()) 784 checkConnectionCloseError(b, hdr, qerr.ConnectionRefused) 785 return len(b), nil 786 }) 787 serv.handlePacket(p) 788 Eventually(done).Should(BeClosed()) 789 790 close(handshakeChan) 791 for i := 0; i < limit; i++ { 792 _, err := serv.Accept(context.Background()) 793 Expect(err).ToNot(HaveOccurred()) 794 } 795 // make sure we can enqueue and accept more connections after that 796 for i := 0; i < limit; i++ { 797 conn := NewMockQUICConn(mockCtrl) 798 conn.EXPECT().closeWithTransportError(gomock.Any()).MaxTimes(1) // called when the server is closed 799 connChan <- conn 800 serv.handlePacket(getInitialWithRandomDestConnID()) 801 } 802 for i := 0; i < limit; i++ { 803 _, err := serv.Accept(context.Background()) 804 Expect(err).ToNot(HaveOccurred()) 805 } 806 }) 807 }) 808 809 Context("token validation", func() { 810 It("decodes the token from the token field", func() { 811 raddr := &net.UDPAddr{IP: net.IPv4(192, 168, 13, 37), Port: 1337} 812 token, err := serv.tokenGenerator.NewRetryToken(raddr, protocol.ConnectionID{}, protocol.ConnectionID{}) 813 Expect(err).ToNot(HaveOccurred()) 814 packet := getPacket(&wire.Header{ 815 Type: protocol.PacketTypeInitial, 816 Token: token, 817 Version: serv.config.Versions[0], 818 }, make([]byte, protocol.MinInitialPacketSize)) 819 packet.remoteAddr = raddr 820 conn.EXPECT().WriteTo(gomock.Any(), gomock.Any()).MaxTimes(1) 821 tracer.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).MaxTimes(1) 822 823 done := make(chan struct{}) 824 phm.EXPECT().Get(gomock.Any()) 825 phm.EXPECT().GetStatelessResetToken(gomock.Any()) 826 phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, _ packetHandler) bool { 827 close(done) 828 return true 829 }) 830 phm.EXPECT().Remove(gomock.Any()).AnyTimes() 831 serv.handlePacket(packet) 832 Eventually(done).Should(BeClosed()) 833 }) 834 835 It("sends an INVALID_TOKEN error, if an invalid retry token is received", func() { 836 serv.maxNumHandshakesUnvalidated = 0 837 token, err := serv.tokenGenerator.NewRetryToken(&net.UDPAddr{}, protocol.ConnectionID{}, protocol.ConnectionID{}) 838 Expect(err).ToNot(HaveOccurred()) 839 hdr := &wire.Header{ 840 Type: protocol.PacketTypeInitial, 841 SrcConnectionID: protocol.ParseConnectionID([]byte{5, 4, 3, 2, 1}), 842 DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}), 843 Token: token, 844 Version: protocol.Version1, 845 } 846 packet := getPacket(hdr, make([]byte, protocol.MinInitialPacketSize)) 847 packet.data = append(packet.data, []byte("coalesced packet")...) // add some garbage to simulate a coalesced packet 848 raddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337} 849 packet.remoteAddr = raddr 850 tracer.EXPECT().SentPacket(packet.remoteAddr, gomock.Any(), gomock.Any(), gomock.Any()).Do(func(_ net.Addr, replyHdr *logging.Header, _ logging.ByteCount, frames []logging.Frame) { 851 Expect(replyHdr.Type).To(Equal(protocol.PacketTypeInitial)) 852 Expect(replyHdr.SrcConnectionID).To(Equal(hdr.DestConnectionID)) 853 Expect(replyHdr.DestConnectionID).To(Equal(hdr.SrcConnectionID)) 854 Expect(frames).To(HaveLen(1)) 855 Expect(frames[0]).To(BeAssignableToTypeOf(&wire.ConnectionCloseFrame{})) 856 ccf := frames[0].(*logging.ConnectionCloseFrame) 857 Expect(ccf.IsApplicationError).To(BeFalse()) 858 Expect(ccf.ErrorCode).To(BeEquivalentTo(qerr.InvalidToken)) 859 }) 860 done := make(chan struct{}) 861 conn.EXPECT().WriteTo(gomock.Any(), raddr).DoAndReturn(func(b []byte, _ net.Addr) (int, error) { 862 defer close(done) 863 checkConnectionCloseError(b, hdr, qerr.InvalidToken) 864 return len(b), nil 865 }) 866 phm.EXPECT().Get(gomock.Any()) 867 serv.handlePacket(packet) 868 Eventually(done).Should(BeClosed()) 869 }) 870 871 It("sends an INVALID_TOKEN error, if an expired retry token is received", func() { 872 serv.maxNumHandshakesUnvalidated = 0 873 serv.config.HandshakeIdleTimeout = time.Millisecond / 2 // the maximum retry token age is equivalent to the handshake timeout 874 Expect(serv.config.maxRetryTokenAge()).To(Equal(time.Millisecond)) 875 raddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337} 876 token, err := serv.tokenGenerator.NewRetryToken(raddr, protocol.ConnectionID{}, protocol.ConnectionID{}) 877 Expect(err).ToNot(HaveOccurred()) 878 time.Sleep(2 * time.Millisecond) // make sure the token is expired 879 hdr := &wire.Header{ 880 Type: protocol.PacketTypeInitial, 881 SrcConnectionID: protocol.ParseConnectionID([]byte{5, 4, 3, 2, 1}), 882 DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}), 883 Token: token, 884 Version: protocol.Version1, 885 } 886 packet := getPacket(hdr, make([]byte, protocol.MinInitialPacketSize)) 887 packet.remoteAddr = raddr 888 tracer.EXPECT().SentPacket(packet.remoteAddr, gomock.Any(), gomock.Any(), gomock.Any()).Do(func(_ net.Addr, replyHdr *logging.Header, _ logging.ByteCount, frames []logging.Frame) { 889 Expect(replyHdr.Type).To(Equal(protocol.PacketTypeInitial)) 890 Expect(replyHdr.SrcConnectionID).To(Equal(hdr.DestConnectionID)) 891 Expect(replyHdr.DestConnectionID).To(Equal(hdr.SrcConnectionID)) 892 Expect(frames).To(HaveLen(1)) 893 Expect(frames[0]).To(BeAssignableToTypeOf(&wire.ConnectionCloseFrame{})) 894 ccf := frames[0].(*logging.ConnectionCloseFrame) 895 Expect(ccf.IsApplicationError).To(BeFalse()) 896 Expect(ccf.ErrorCode).To(BeEquivalentTo(qerr.InvalidToken)) 897 }) 898 done := make(chan struct{}) 899 conn.EXPECT().WriteTo(gomock.Any(), raddr).DoAndReturn(func(b []byte, _ net.Addr) (int, error) { 900 defer close(done) 901 checkConnectionCloseError(b, hdr, qerr.InvalidToken) 902 return len(b), nil 903 }) 904 phm.EXPECT().Get(gomock.Any()) 905 serv.handlePacket(packet) 906 Eventually(done).Should(BeClosed()) 907 }) 908 909 It("doesn't send an INVALID_TOKEN error, if an invalid non-retry token is received", func() { 910 serv.maxNumHandshakesUnvalidated = 0 911 token, err := serv.tokenGenerator.NewToken(&net.UDPAddr{IP: net.IPv4(192, 168, 0, 1), Port: 1337}) 912 Expect(err).ToNot(HaveOccurred()) 913 hdr := &wire.Header{ 914 Type: protocol.PacketTypeInitial, 915 SrcConnectionID: protocol.ParseConnectionID([]byte{5, 4, 3, 2, 1}), 916 DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}), 917 Token: token, 918 Version: protocol.Version1, 919 } 920 packet := getPacket(hdr, make([]byte, protocol.MinInitialPacketSize)) 921 packet.data[len(packet.data)-10] ^= 0xff // corrupt the packet 922 raddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337} 923 packet.remoteAddr = raddr 924 tracer.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).MaxTimes(1) 925 done := make(chan struct{}) 926 conn.EXPECT().WriteTo(gomock.Any(), raddr).DoAndReturn(func(b []byte, _ net.Addr) (int, error) { 927 defer close(done) 928 replyHdr := parseHeader(b) 929 Expect(replyHdr.Type).To(Equal(protocol.PacketTypeRetry)) 930 return len(b), nil 931 }) 932 phm.EXPECT().Get(gomock.Any()) 933 serv.handlePacket(packet) 934 // make sure there are no Write calls on the packet conn 935 Eventually(done).Should(BeClosed()) 936 }) 937 938 It("sends an INVALID_TOKEN error, if an expired non-retry token is received", func() { 939 serv.maxNumHandshakesUnvalidated = 0 940 serv.maxTokenAge = time.Millisecond 941 raddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337} 942 token, err := serv.tokenGenerator.NewToken(raddr) 943 Expect(err).ToNot(HaveOccurred()) 944 time.Sleep(2 * time.Millisecond) // make sure the token is expired 945 hdr := &wire.Header{ 946 Type: protocol.PacketTypeInitial, 947 SrcConnectionID: protocol.ParseConnectionID([]byte{5, 4, 3, 2, 1}), 948 DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}), 949 Token: token, 950 Version: protocol.Version1, 951 } 952 packet := getPacket(hdr, make([]byte, protocol.MinInitialPacketSize)) 953 packet.remoteAddr = raddr 954 tracer.EXPECT().SentPacket(packet.remoteAddr, gomock.Any(), gomock.Any(), gomock.Any()).Do(func(_ net.Addr, replyHdr *logging.Header, _ logging.ByteCount, frames []logging.Frame) { 955 Expect(replyHdr.Type).To(Equal(protocol.PacketTypeRetry)) 956 }) 957 done := make(chan struct{}) 958 conn.EXPECT().WriteTo(gomock.Any(), raddr).DoAndReturn(func(b []byte, _ net.Addr) (int, error) { 959 defer close(done) 960 return len(b), nil 961 }) 962 phm.EXPECT().Get(gomock.Any()) 963 serv.handlePacket(packet) 964 Eventually(done).Should(BeClosed()) 965 }) 966 967 It("doesn't send an INVALID_TOKEN error, if the packet is corrupted", func() { 968 serv.maxNumHandshakesUnvalidated = 0 969 token, err := serv.tokenGenerator.NewRetryToken(&net.UDPAddr{}, protocol.ConnectionID{}, protocol.ConnectionID{}) 970 Expect(err).ToNot(HaveOccurred()) 971 hdr := &wire.Header{ 972 Type: protocol.PacketTypeInitial, 973 SrcConnectionID: protocol.ParseConnectionID([]byte{5, 4, 3, 2, 1}), 974 DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}), 975 Token: token, 976 Version: protocol.Version1, 977 } 978 packet := getPacket(hdr, make([]byte, protocol.MinInitialPacketSize)) 979 packet.data[len(packet.data)-10] ^= 0xff // corrupt the packet 980 packet.remoteAddr = &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337} 981 done := make(chan struct{}) 982 tracer.EXPECT().DroppedPacket(packet.remoteAddr, logging.PacketTypeInitial, packet.Size(), logging.PacketDropPayloadDecryptError).Do(func(net.Addr, logging.PacketType, protocol.ByteCount, logging.PacketDropReason) { close(done) }) 983 phm.EXPECT().Get(gomock.Any()) 984 serv.handlePacket(packet) 985 // make sure there are no Write calls on the packet conn 986 time.Sleep(50 * time.Millisecond) 987 Eventually(done).Should(BeClosed()) 988 }) 989 }) 990 991 Context("accepting connections", func() { 992 It("returns Accept when closed", func() { 993 done := make(chan struct{}) 994 go func() { 995 defer GinkgoRecover() 996 _, err := serv.Accept(context.Background()) 997 Expect(err).To(MatchError(ErrServerClosed)) 998 close(done) 999 }() 1000 1001 serv.Close() 1002 Eventually(done).Should(BeClosed()) 1003 }) 1004 1005 It("returns immediately, if an error occurred before", func() { 1006 serv.Close() 1007 for i := 0; i < 3; i++ { 1008 _, err := serv.Accept(context.Background()) 1009 Expect(err).To(MatchError(ErrServerClosed)) 1010 } 1011 }) 1012 1013 It("closes connection that are still handshaking after Close", func() { 1014 serv.Close() 1015 1016 destroyed := make(chan struct{}) 1017 serv.newConn = func( 1018 _ sendConn, 1019 _ connRunner, 1020 _ protocol.ConnectionID, 1021 _ *protocol.ConnectionID, 1022 _ protocol.ConnectionID, 1023 _ protocol.ConnectionID, 1024 _ protocol.ConnectionID, 1025 _ ConnectionIDGenerator, 1026 _ protocol.StatelessResetToken, 1027 conf *Config, 1028 _ *tls.Config, 1029 _ *handshake.TokenGenerator, 1030 _ bool, 1031 _ *logging.ConnectionTracer, 1032 _ uint64, 1033 _ utils.Logger, 1034 _ protocol.Version, 1035 ) quicConn { 1036 conn := NewMockQUICConn(mockCtrl) 1037 conn.EXPECT().handlePacket(gomock.Any()) 1038 conn.EXPECT().closeWithTransportError(ConnectionRefused).Do(func(TransportErrorCode) { close(destroyed) }) 1039 conn.EXPECT().HandshakeComplete().Return(make(chan struct{})) 1040 conn.EXPECT().run().MaxTimes(1) 1041 conn.EXPECT().Context().Return(context.Background()) 1042 return conn 1043 } 1044 phm.EXPECT().Get(gomock.Any()) 1045 phm.EXPECT().GetStatelessResetToken(gomock.Any()) 1046 phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).Return(true) 1047 serv.handleInitialImpl( 1048 receivedPacket{buffer: getPacketBuffer()}, 1049 &wire.Header{DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8})}, 1050 ) 1051 Eventually(destroyed).Should(BeClosed()) 1052 }) 1053 1054 It("returns when the context is canceled", func() { 1055 ctx, cancel := context.WithCancel(context.Background()) 1056 done := make(chan struct{}) 1057 go func() { 1058 defer GinkgoRecover() 1059 _, err := serv.Accept(ctx) 1060 Expect(err).To(MatchError("context canceled")) 1061 close(done) 1062 }() 1063 1064 Consistently(done).ShouldNot(BeClosed()) 1065 cancel() 1066 Eventually(done).Should(BeClosed()) 1067 }) 1068 1069 It("uses the config returned by GetConfigClient", func() { 1070 conn := NewMockQUICConn(mockCtrl) 1071 1072 conf := &Config{MaxIncomingStreams: 1234} 1073 serv.config = populateConfig(&Config{GetConfigForClient: func(*ClientHelloInfo) (*Config, error) { return conf, nil }}) 1074 done := make(chan struct{}) 1075 go func() { 1076 defer GinkgoRecover() 1077 s, err := serv.Accept(context.Background()) 1078 Expect(err).ToNot(HaveOccurred()) 1079 Expect(s).To(Equal(conn)) 1080 close(done) 1081 }() 1082 1083 handshakeChan := make(chan struct{}) 1084 serv.newConn = func( 1085 _ sendConn, 1086 _ connRunner, 1087 _ protocol.ConnectionID, 1088 _ *protocol.ConnectionID, 1089 _ protocol.ConnectionID, 1090 _ protocol.ConnectionID, 1091 _ protocol.ConnectionID, 1092 _ ConnectionIDGenerator, 1093 _ protocol.StatelessResetToken, 1094 conf *Config, 1095 _ *tls.Config, 1096 _ *handshake.TokenGenerator, 1097 _ bool, 1098 _ *logging.ConnectionTracer, 1099 _ uint64, 1100 _ utils.Logger, 1101 _ protocol.Version, 1102 ) quicConn { 1103 Expect(conf.MaxIncomingStreams).To(BeEquivalentTo(1234)) 1104 conn.EXPECT().handlePacket(gomock.Any()) 1105 conn.EXPECT().HandshakeComplete().Return(handshakeChan) 1106 conn.EXPECT().run() 1107 conn.EXPECT().Context().Return(context.Background()) 1108 return conn 1109 } 1110 phm.EXPECT().Get(gomock.Any()) 1111 phm.EXPECT().GetStatelessResetToken(gomock.Any()) 1112 phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).Return(true) 1113 serv.handleInitialImpl( 1114 receivedPacket{buffer: getPacketBuffer()}, 1115 &wire.Header{DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8})}, 1116 ) 1117 Consistently(done).ShouldNot(BeClosed()) 1118 close(handshakeChan) // complete the handshake 1119 Eventually(done).Should(BeClosed()) 1120 }) 1121 1122 It("rejects a connection attempt when GetConfigClient returns an error", func() { 1123 serv.config = populateConfig(&Config{GetConfigForClient: func(*ClientHelloInfo) (*Config, error) { return nil, errors.New("rejected") }}) 1124 1125 phm.EXPECT().Get(gomock.Any()) 1126 done := make(chan struct{}) 1127 tracer.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) 1128 conn.EXPECT().WriteTo(gomock.Any(), gomock.Any()).DoAndReturn(func(b []byte, _ net.Addr) (int, error) { 1129 defer close(done) 1130 rejectHdr := parseHeader(b) 1131 Expect(rejectHdr.Type).To(Equal(protocol.PacketTypeInitial)) 1132 return len(b), nil 1133 }) 1134 serv.handleInitialImpl( 1135 receivedPacket{buffer: getPacketBuffer()}, 1136 &wire.Header{DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}), Version: protocol.Version1}, 1137 ) 1138 Eventually(done).Should(BeClosed()) 1139 }) 1140 1141 It("accepts new connections when the handshake completes", func() { 1142 conn := NewMockQUICConn(mockCtrl) 1143 1144 done := make(chan struct{}) 1145 go func() { 1146 defer GinkgoRecover() 1147 s, err := serv.Accept(context.Background()) 1148 Expect(err).ToNot(HaveOccurred()) 1149 Expect(s).To(Equal(conn)) 1150 close(done) 1151 }() 1152 1153 handshakeChan := make(chan struct{}) 1154 serv.newConn = func( 1155 _ sendConn, 1156 runner connRunner, 1157 _ protocol.ConnectionID, 1158 _ *protocol.ConnectionID, 1159 _ protocol.ConnectionID, 1160 _ protocol.ConnectionID, 1161 _ protocol.ConnectionID, 1162 _ ConnectionIDGenerator, 1163 _ protocol.StatelessResetToken, 1164 _ *Config, 1165 _ *tls.Config, 1166 _ *handshake.TokenGenerator, 1167 _ bool, 1168 _ *logging.ConnectionTracer, 1169 _ uint64, 1170 _ utils.Logger, 1171 _ protocol.Version, 1172 ) quicConn { 1173 conn.EXPECT().handlePacket(gomock.Any()) 1174 conn.EXPECT().HandshakeComplete().Return(handshakeChan) 1175 conn.EXPECT().run() 1176 conn.EXPECT().Context().Return(context.Background()) 1177 return conn 1178 } 1179 phm.EXPECT().Get(gomock.Any()) 1180 phm.EXPECT().GetStatelessResetToken(gomock.Any()) 1181 phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).Return(true) 1182 serv.handleInitialImpl( 1183 receivedPacket{buffer: getPacketBuffer()}, 1184 &wire.Header{DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8})}, 1185 ) 1186 Consistently(done).ShouldNot(BeClosed()) 1187 close(handshakeChan) // complete the handshake 1188 Eventually(done).Should(BeClosed()) 1189 }) 1190 }) 1191 }) 1192 1193 Context("server accepting connections that haven't completed the handshake", func() { 1194 var ( 1195 serv *EarlyListener 1196 phm *MockPacketHandlerManager 1197 ) 1198 1199 BeforeEach(func() { 1200 var err error 1201 serv, err = ListenEarly(conn, tlsConf, nil) 1202 Expect(err).ToNot(HaveOccurred()) 1203 phm = NewMockPacketHandlerManager(mockCtrl) 1204 serv.baseServer.connHandler = phm 1205 }) 1206 1207 AfterEach(func() { 1208 serv.Close() 1209 }) 1210 1211 It("accepts new connections when they become ready", func() { 1212 conn := NewMockQUICConn(mockCtrl) 1213 1214 done := make(chan struct{}) 1215 go func() { 1216 defer GinkgoRecover() 1217 s, err := serv.Accept(context.Background()) 1218 Expect(err).ToNot(HaveOccurred()) 1219 Expect(s).To(Equal(conn)) 1220 close(done) 1221 }() 1222 1223 ready := make(chan struct{}) 1224 serv.baseServer.newConn = func( 1225 _ sendConn, 1226 runner connRunner, 1227 _ protocol.ConnectionID, 1228 _ *protocol.ConnectionID, 1229 _ protocol.ConnectionID, 1230 _ protocol.ConnectionID, 1231 _ protocol.ConnectionID, 1232 _ ConnectionIDGenerator, 1233 _ protocol.StatelessResetToken, 1234 _ *Config, 1235 _ *tls.Config, 1236 _ *handshake.TokenGenerator, 1237 _ bool, 1238 _ *logging.ConnectionTracer, 1239 _ uint64, 1240 _ utils.Logger, 1241 _ protocol.Version, 1242 ) quicConn { 1243 conn.EXPECT().handlePacket(gomock.Any()) 1244 conn.EXPECT().run() 1245 conn.EXPECT().earlyConnReady().Return(ready) 1246 conn.EXPECT().Context().Return(context.Background()) 1247 return conn 1248 } 1249 phm.EXPECT().Get(gomock.Any()) 1250 phm.EXPECT().GetStatelessResetToken(gomock.Any()) 1251 phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).Return(true) 1252 serv.baseServer.handleInitialImpl( 1253 receivedPacket{buffer: getPacketBuffer()}, 1254 &wire.Header{DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8})}, 1255 ) 1256 Consistently(done).ShouldNot(BeClosed()) 1257 close(ready) 1258 Eventually(done).Should(BeClosed()) 1259 }) 1260 1261 It("rejects new connection attempts if the accept queue is full", func() { 1262 connChan := make(chan *MockQUICConn, 1) 1263 var wg sync.WaitGroup // to make sure the test fully completes 1264 wg.Add(protocol.MaxAcceptQueueSize + 1) 1265 serv.baseServer.newConn = func( 1266 _ sendConn, 1267 runner connRunner, 1268 _ protocol.ConnectionID, 1269 _ *protocol.ConnectionID, 1270 _ protocol.ConnectionID, 1271 _ protocol.ConnectionID, 1272 _ protocol.ConnectionID, 1273 _ ConnectionIDGenerator, 1274 _ protocol.StatelessResetToken, 1275 _ *Config, 1276 _ *tls.Config, 1277 _ *handshake.TokenGenerator, 1278 _ bool, 1279 _ *logging.ConnectionTracer, 1280 _ uint64, 1281 _ utils.Logger, 1282 _ protocol.Version, 1283 ) quicConn { 1284 defer wg.Done() 1285 ready := make(chan struct{}) 1286 close(ready) 1287 conn := <-connChan 1288 conn.EXPECT().handlePacket(gomock.Any()) 1289 conn.EXPECT().run() 1290 conn.EXPECT().earlyConnReady().Return(ready) 1291 conn.EXPECT().Context().Return(context.Background()) 1292 return conn 1293 } 1294 1295 phm.EXPECT().Get(gomock.Any()).AnyTimes() 1296 phm.EXPECT().GetStatelessResetToken(gomock.Any()).Times(protocol.MaxAcceptQueueSize) 1297 phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).Return(true).Times(protocol.MaxAcceptQueueSize) 1298 for i := 0; i < protocol.MaxAcceptQueueSize; i++ { 1299 conn := NewMockQUICConn(mockCtrl) 1300 connChan <- conn 1301 serv.baseServer.handlePacket(getInitialWithRandomDestConnID()) 1302 } 1303 1304 Eventually(serv.baseServer.connQueue).Should(HaveLen(protocol.MaxAcceptQueueSize)) 1305 1306 phm.EXPECT().GetStatelessResetToken(gomock.Any()) 1307 phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).Return(true) 1308 conn := NewMockQUICConn(mockCtrl) 1309 conn.EXPECT().closeWithTransportError(ConnectionRefused) 1310 connChan <- conn 1311 serv.baseServer.handlePacket(getInitialWithRandomDestConnID()) 1312 wg.Wait() 1313 }) 1314 1315 It("doesn't accept new connections if they were closed in the mean time", func() { 1316 p := getInitial(protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10})) 1317 ctx, cancel := context.WithCancel(context.Background()) 1318 connCreated := make(chan struct{}) 1319 conn := NewMockQUICConn(mockCtrl) 1320 serv.baseServer.newConn = func( 1321 _ sendConn, 1322 runner connRunner, 1323 _ protocol.ConnectionID, 1324 _ *protocol.ConnectionID, 1325 _ protocol.ConnectionID, 1326 _ protocol.ConnectionID, 1327 _ protocol.ConnectionID, 1328 _ ConnectionIDGenerator, 1329 _ protocol.StatelessResetToken, 1330 _ *Config, 1331 _ *tls.Config, 1332 _ *handshake.TokenGenerator, 1333 _ bool, 1334 _ *logging.ConnectionTracer, 1335 _ uint64, 1336 _ utils.Logger, 1337 _ protocol.Version, 1338 ) quicConn { 1339 conn.EXPECT().handlePacket(p) 1340 conn.EXPECT().run() 1341 conn.EXPECT().earlyConnReady() 1342 conn.EXPECT().Context().Return(ctx) 1343 close(connCreated) 1344 return conn 1345 } 1346 1347 phm.EXPECT().Get(gomock.Any()) 1348 phm.EXPECT().GetStatelessResetToken(gomock.Any()) 1349 phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).Return(true) 1350 serv.baseServer.handlePacket(p) 1351 // make sure there are no Write calls on the packet conn 1352 time.Sleep(50 * time.Millisecond) 1353 Eventually(connCreated).Should(BeClosed()) 1354 cancel() 1355 time.Sleep(scaleDuration(200 * time.Millisecond)) 1356 1357 done := make(chan struct{}) 1358 go func() { 1359 defer GinkgoRecover() 1360 serv.Accept(context.Background()) 1361 close(done) 1362 }() 1363 Consistently(done).ShouldNot(BeClosed()) 1364 1365 // make the go routine return 1366 conn.EXPECT().getPerspective().MaxTimes(2) // initOnce for every conn ID 1367 Expect(serv.Close()).To(Succeed()) 1368 Eventually(done).Should(BeClosed()) 1369 }) 1370 }) 1371 1372 Context("0-RTT", func() { 1373 var ( 1374 tr *Transport 1375 serv *baseServer 1376 phm *MockPacketHandlerManager 1377 tracer *mocklogging.MockTracer 1378 ) 1379 1380 BeforeEach(func() { 1381 var t *logging.Tracer 1382 t, tracer = mocklogging.NewMockTracer(mockCtrl) 1383 tr = &Transport{Conn: conn, Tracer: t} 1384 ln, err := tr.ListenEarly(tlsConf, nil) 1385 Expect(err).ToNot(HaveOccurred()) 1386 phm = NewMockPacketHandlerManager(mockCtrl) 1387 serv = ln.baseServer 1388 serv.connHandler = phm 1389 }) 1390 1391 AfterEach(func() { 1392 tracer.EXPECT().Close() 1393 Expect(tr.Close()).To(Succeed()) 1394 }) 1395 1396 It("passes packets to existing connections", func() { 1397 connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}) 1398 p := getPacket(&wire.Header{ 1399 Type: protocol.PacketType0RTT, 1400 DestConnectionID: connID, 1401 Version: serv.config.Versions[0], 1402 }, make([]byte, 100)) 1403 conn := NewMockPacketHandler(mockCtrl) 1404 phm.EXPECT().Get(connID).Return(conn, true) 1405 handled := make(chan struct{}) 1406 conn.EXPECT().handlePacket(p).Do(func(receivedPacket) { close(handled) }) 1407 serv.handlePacket(p) 1408 Eventually(handled).Should(BeClosed()) 1409 }) 1410 1411 It("queues 0-RTT packets, up to Max0RTTQueueSize", func() { 1412 connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}) 1413 1414 var zeroRTTPackets []receivedPacket 1415 1416 for i := 0; i < protocol.Max0RTTQueueLen; i++ { 1417 p := getPacket(&wire.Header{ 1418 Type: protocol.PacketType0RTT, 1419 DestConnectionID: connID, 1420 Version: serv.config.Versions[0], 1421 }, make([]byte, 100+i)) 1422 phm.EXPECT().Get(connID) 1423 serv.handlePacket(p) 1424 zeroRTTPackets = append(zeroRTTPackets, p) 1425 } 1426 1427 // send one more packet, this one should be dropped 1428 p := getPacket(&wire.Header{ 1429 Type: protocol.PacketType0RTT, 1430 DestConnectionID: connID, 1431 Version: serv.config.Versions[0], 1432 }, make([]byte, 200)) 1433 phm.EXPECT().Get(connID) 1434 tracer.EXPECT().DroppedPacket(p.remoteAddr, logging.PacketType0RTT, p.Size(), logging.PacketDropDOSPrevention) 1435 serv.handlePacket(p) 1436 1437 initial := getPacket(&wire.Header{ 1438 Type: protocol.PacketTypeInitial, 1439 DestConnectionID: connID, 1440 Version: serv.config.Versions[0], 1441 }, make([]byte, protocol.MinInitialPacketSize)) 1442 called := make(chan struct{}) 1443 serv.newConn = func( 1444 _ sendConn, 1445 _ connRunner, 1446 _ protocol.ConnectionID, 1447 _ *protocol.ConnectionID, 1448 _ protocol.ConnectionID, 1449 _ protocol.ConnectionID, 1450 _ protocol.ConnectionID, 1451 _ ConnectionIDGenerator, 1452 _ protocol.StatelessResetToken, 1453 _ *Config, 1454 _ *tls.Config, 1455 _ *handshake.TokenGenerator, 1456 _ bool, 1457 _ *logging.ConnectionTracer, 1458 _ uint64, 1459 _ utils.Logger, 1460 _ protocol.Version, 1461 ) quicConn { 1462 conn := NewMockQUICConn(mockCtrl) 1463 var calls []any 1464 calls = append(calls, conn.EXPECT().handlePacket(initial)) 1465 for _, p := range zeroRTTPackets { 1466 calls = append(calls, conn.EXPECT().handlePacket(p)) 1467 } 1468 gomock.InOrder(calls...) 1469 conn.EXPECT().run() 1470 conn.EXPECT().earlyConnReady() 1471 conn.EXPECT().Context().Return(context.Background()) 1472 close(called) 1473 // shutdown 1474 conn.EXPECT().closeWithTransportError(gomock.Any()) 1475 return conn 1476 } 1477 1478 phm.EXPECT().Get(connID) 1479 phm.EXPECT().GetStatelessResetToken(gomock.Any()) 1480 phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).Return(true) 1481 serv.handlePacket(initial) 1482 Eventually(called).Should(BeClosed()) 1483 }) 1484 1485 It("limits the number of queues", func() { 1486 for i := 0; i < protocol.Max0RTTQueues; i++ { 1487 b := make([]byte, 16) 1488 rand.Read(b) 1489 connID := protocol.ParseConnectionID(b) 1490 p := getPacket(&wire.Header{ 1491 Type: protocol.PacketType0RTT, 1492 DestConnectionID: connID, 1493 Version: serv.config.Versions[0], 1494 }, make([]byte, 100+i)) 1495 phm.EXPECT().Get(connID) 1496 serv.handlePacket(p) 1497 } 1498 1499 connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}) 1500 p := getPacket(&wire.Header{ 1501 Type: protocol.PacketType0RTT, 1502 DestConnectionID: connID, 1503 Version: serv.config.Versions[0], 1504 }, make([]byte, 200)) 1505 phm.EXPECT().Get(connID) 1506 dropped := make(chan struct{}) 1507 tracer.EXPECT().DroppedPacket(p.remoteAddr, logging.PacketType0RTT, p.Size(), logging.PacketDropDOSPrevention).Do(func(net.Addr, logging.PacketType, protocol.ByteCount, logging.PacketDropReason) { 1508 close(dropped) 1509 }) 1510 serv.handlePacket(p) 1511 Eventually(dropped).Should(BeClosed()) 1512 }) 1513 1514 It("drops queues after a while", func() { 1515 now := time.Now() 1516 1517 connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}) 1518 p := getPacket(&wire.Header{ 1519 Type: protocol.PacketType0RTT, 1520 DestConnectionID: connID, 1521 Version: serv.config.Versions[0], 1522 }, make([]byte, 200)) 1523 p.rcvTime = now 1524 1525 connID2 := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 9}) 1526 p2Time := now.Add(protocol.Max0RTTQueueingDuration / 2) 1527 p2 := getPacket(&wire.Header{ 1528 Type: protocol.PacketType0RTT, 1529 DestConnectionID: connID2, 1530 Version: serv.config.Versions[0], 1531 }, make([]byte, 300)) 1532 p2.rcvTime = p2Time // doesn't trigger the cleanup of the first packet 1533 1534 dropped1 := make(chan struct{}) 1535 dropped2 := make(chan struct{}) 1536 // need to register the call before handling the packet to avoid race condition 1537 gomock.InOrder( 1538 tracer.EXPECT().DroppedPacket(p.remoteAddr, logging.PacketType0RTT, p.Size(), logging.PacketDropDOSPrevention).Do(func(net.Addr, logging.PacketType, protocol.ByteCount, logging.PacketDropReason) { 1539 close(dropped1) 1540 }), 1541 tracer.EXPECT().DroppedPacket(p2.remoteAddr, logging.PacketType0RTT, p2.Size(), logging.PacketDropDOSPrevention).Do(func(net.Addr, logging.PacketType, protocol.ByteCount, logging.PacketDropReason) { 1542 close(dropped2) 1543 }), 1544 ) 1545 1546 phm.EXPECT().Get(connID) 1547 serv.handlePacket(p) 1548 1549 // There's no cleanup Go routine. 1550 // Cleanup is triggered when new packets are received. 1551 1552 phm.EXPECT().Get(connID2) 1553 serv.handlePacket(p2) 1554 // make sure no cleanup is executed 1555 Consistently(dropped1, 50*time.Millisecond).ShouldNot(BeClosed()) 1556 1557 // There's no cleanup Go routine. 1558 // Cleanup is triggered when new packets are received. 1559 connID3 := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 0}) 1560 p3 := getPacket(&wire.Header{ 1561 Type: protocol.PacketType0RTT, 1562 DestConnectionID: connID3, 1563 Version: serv.config.Versions[0], 1564 }, make([]byte, 200)) 1565 p3.rcvTime = now.Add(protocol.Max0RTTQueueingDuration + time.Nanosecond) // now triggers the cleanup 1566 phm.EXPECT().Get(connID3) 1567 serv.handlePacket(p3) 1568 Eventually(dropped1).Should(BeClosed()) 1569 Consistently(dropped2, 50*time.Millisecond).ShouldNot(BeClosed()) 1570 1571 // make sure the second packet is also cleaned up 1572 connID4 := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 1}) 1573 p4 := getPacket(&wire.Header{ 1574 Type: protocol.PacketType0RTT, 1575 DestConnectionID: connID4, 1576 Version: serv.config.Versions[0], 1577 }, make([]byte, 200)) 1578 p4.rcvTime = p2Time.Add(protocol.Max0RTTQueueingDuration + time.Nanosecond) // now triggers the cleanup 1579 phm.EXPECT().Get(connID4) 1580 serv.handlePacket(p4) 1581 Eventually(dropped2).Should(BeClosed()) 1582 }) 1583 }) 1584 })