github.com/apernet/quic-go@v0.43.1-0.20240515053213-5e9e635fd9f0/server_test.go (about) 1 package quic 2 3 import ( 4 "context" 5 "crypto/rand" 6 "crypto/tls" 7 "errors" 8 "net" 9 "sync" 10 "sync/atomic" 11 "time" 12 13 "golang.org/x/time/rate" 14 15 "github.com/apernet/quic-go/internal/handshake" 16 mocklogging "github.com/apernet/quic-go/internal/mocks/logging" 17 "github.com/apernet/quic-go/internal/protocol" 18 "github.com/apernet/quic-go/internal/qerr" 19 "github.com/apernet/quic-go/internal/testdata" 20 "github.com/apernet/quic-go/internal/utils" 21 "github.com/apernet/quic-go/internal/wire" 22 "github.com/apernet/quic-go/logging" 23 24 . "github.com/onsi/ginkgo/v2" 25 . "github.com/onsi/gomega" 26 "go.uber.org/mock/gomock" 27 ) 28 29 var _ = Describe("Server", func() { 30 var ( 31 conn *MockPacketConn 32 tlsConf *tls.Config 33 ) 34 35 getPacket := func(hdr *wire.Header, p []byte) receivedPacket { 36 buf := getPacketBuffer() 37 hdr.Length = 4 + protocol.ByteCount(len(p)) + 16 38 var err error 39 buf.Data, err = (&wire.ExtendedHeader{ 40 Header: *hdr, 41 PacketNumber: 0x42, 42 PacketNumberLen: protocol.PacketNumberLen4, 43 }).Append(buf.Data, protocol.Version1) 44 Expect(err).ToNot(HaveOccurred()) 45 n := len(buf.Data) 46 buf.Data = append(buf.Data, p...) 47 data := buf.Data 48 sealer, _ := handshake.NewInitialAEAD(hdr.DestConnectionID, protocol.PerspectiveClient, hdr.Version) 49 _ = sealer.Seal(data[n:n], data[n:], 0x42, data[:n]) 50 data = data[:len(data)+16] 51 sealer.EncryptHeader(data[n:n+16], &data[0], data[n-4:n]) 52 return receivedPacket{ 53 rcvTime: time.Now(), 54 remoteAddr: &net.UDPAddr{IP: net.IPv4(4, 5, 6, 7), Port: 456}, 55 data: data, 56 buffer: buf, 57 } 58 } 59 60 getInitial := func(destConnID protocol.ConnectionID) receivedPacket { 61 senderAddr := &net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 42} 62 hdr := &wire.Header{ 63 Type: protocol.PacketTypeInitial, 64 SrcConnectionID: protocol.ParseConnectionID([]byte{5, 4, 3, 2, 1}), 65 DestConnectionID: destConnID, 66 Version: protocol.Version1, 67 } 68 p := getPacket(hdr, make([]byte, protocol.MinInitialPacketSize)) 69 p.buffer = getPacketBuffer() 70 p.remoteAddr = senderAddr 71 return p 72 } 73 74 getInitialWithRandomDestConnID := func() receivedPacket { 75 b := make([]byte, 10) 76 _, err := rand.Read(b) 77 Expect(err).ToNot(HaveOccurred()) 78 79 return getInitial(protocol.ParseConnectionID(b)) 80 } 81 82 parseHeader := func(data []byte) *wire.Header { 83 hdr, _, _, err := wire.ParsePacket(data) 84 Expect(err).ToNot(HaveOccurred()) 85 return hdr 86 } 87 88 checkConnectionCloseError := func(b []byte, origHdr *wire.Header, errorCode qerr.TransportErrorCode) { 89 replyHdr := parseHeader(b) 90 Expect(replyHdr.Type).To(Equal(protocol.PacketTypeInitial)) 91 Expect(replyHdr.SrcConnectionID).To(Equal(origHdr.DestConnectionID)) 92 Expect(replyHdr.DestConnectionID).To(Equal(origHdr.SrcConnectionID)) 93 _, opener := handshake.NewInitialAEAD(origHdr.DestConnectionID, protocol.PerspectiveClient, replyHdr.Version) 94 extHdr, err := unpackLongHeader(opener, replyHdr, b, origHdr.Version) 95 Expect(err).ToNot(HaveOccurred()) 96 data, err := opener.Open(nil, b[extHdr.ParsedLen():], extHdr.PacketNumber, b[:extHdr.ParsedLen()]) 97 Expect(err).ToNot(HaveOccurred()) 98 _, f, err := wire.NewFrameParser(false).ParseNext(data, protocol.EncryptionInitial, origHdr.Version) 99 Expect(err).ToNot(HaveOccurred()) 100 Expect(f).To(BeAssignableToTypeOf(&wire.ConnectionCloseFrame{})) 101 ccf := f.(*wire.ConnectionCloseFrame) 102 Expect(ccf.IsApplicationError).To(BeFalse()) 103 Expect(ccf.ErrorCode).To(BeEquivalentTo(errorCode)) 104 Expect(ccf.ReasonPhrase).To(BeEmpty()) 105 } 106 107 BeforeEach(func() { 108 conn = NewMockPacketConn(mockCtrl) 109 conn.EXPECT().LocalAddr().Return(&net.UDPAddr{}).AnyTimes() 110 wait := make(chan struct{}) 111 conn.EXPECT().ReadFrom(gomock.Any()).DoAndReturn(func(_ []byte) (int, net.Addr, error) { 112 <-wait 113 return 0, nil, errors.New("done") 114 }).MaxTimes(1) 115 conn.EXPECT().SetReadDeadline(gomock.Any()).Do(func(time.Time) error { 116 close(wait) 117 conn.EXPECT().SetReadDeadline(time.Time{}) 118 return nil 119 }).MaxTimes(1) 120 tlsConf = testdata.GetTLSConfig() 121 tlsConf.NextProtos = []string{"proto1"} 122 }) 123 124 It("errors when no tls.Config is given", func() { 125 _, err := ListenAddr("localhost:0", nil, nil) 126 Expect(err).To(HaveOccurred()) 127 Expect(err.Error()).To(ContainSubstring("quic: tls.Config not set")) 128 }) 129 130 It("errors when the Config contains an invalid version", func() { 131 version := protocol.Version(0x1234) 132 _, err := Listen(nil, tlsConf, &Config{Versions: []protocol.Version{version}}) 133 Expect(err).To(MatchError("invalid QUIC version: 0x1234")) 134 }) 135 136 It("fills in default values if options are not set in the Config", func() { 137 ln, err := Listen(conn, tlsConf, &Config{}) 138 Expect(err).ToNot(HaveOccurred()) 139 server := ln.baseServer 140 Expect(server.config.Versions).To(Equal(protocol.SupportedVersions)) 141 Expect(server.config.HandshakeIdleTimeout).To(Equal(protocol.DefaultHandshakeIdleTimeout)) 142 Expect(server.config.MaxIdleTimeout).To(Equal(protocol.DefaultIdleTimeout)) 143 Expect(server.config.KeepAlivePeriod).To(BeZero()) 144 // stop the listener 145 Expect(ln.Close()).To(Succeed()) 146 }) 147 148 It("setups with the right values", func() { 149 supportedVersions := []protocol.Version{protocol.Version1} 150 config := Config{ 151 Versions: supportedVersions, 152 HandshakeIdleTimeout: 1337 * time.Hour, 153 MaxIdleTimeout: 42 * time.Minute, 154 KeepAlivePeriod: 5 * time.Second, 155 } 156 ln, err := Listen(conn, tlsConf, &config) 157 Expect(err).ToNot(HaveOccurred()) 158 server := ln.baseServer 159 Expect(server.connHandler).ToNot(BeNil()) 160 Expect(server.config.Versions).To(Equal(supportedVersions)) 161 Expect(server.config.HandshakeIdleTimeout).To(Equal(1337 * time.Hour)) 162 Expect(server.config.MaxIdleTimeout).To(Equal(42 * time.Minute)) 163 Expect(server.config.KeepAlivePeriod).To(Equal(5 * time.Second)) 164 // stop the listener 165 Expect(ln.Close()).To(Succeed()) 166 }) 167 168 It("listens on a given address", func() { 169 addr := "127.0.0.1:13579" 170 ln, err := ListenAddr(addr, tlsConf, &Config{}) 171 Expect(err).ToNot(HaveOccurred()) 172 Expect(ln.Addr().String()).To(Equal(addr)) 173 // stop the listener 174 Expect(ln.Close()).To(Succeed()) 175 }) 176 177 It("errors if given an invalid address", func() { 178 addr := "127.0.0.1" 179 _, err := ListenAddr(addr, tlsConf, &Config{}) 180 Expect(err).To(BeAssignableToTypeOf(&net.AddrError{})) 181 }) 182 183 It("errors if given an invalid address", func() { 184 addr := "1.1.1.1:1111" 185 _, err := ListenAddr(addr, tlsConf, &Config{}) 186 Expect(err).To(BeAssignableToTypeOf(&net.OpError{})) 187 }) 188 189 Context("server accepting connections that completed the handshake", func() { 190 var ( 191 tr *Transport 192 serv *baseServer 193 phm *MockPacketHandlerManager 194 tracer *mocklogging.MockTracer 195 ) 196 197 BeforeEach(func() { 198 var t *logging.Tracer 199 t, tracer = mocklogging.NewMockTracer(mockCtrl) 200 tr = &Transport{Conn: conn, Tracer: t} 201 ln, err := tr.Listen(tlsConf, nil) 202 Expect(err).ToNot(HaveOccurred()) 203 serv = ln.baseServer 204 phm = NewMockPacketHandlerManager(mockCtrl) 205 serv.connHandler = phm 206 }) 207 208 AfterEach(func() { 209 tracer.EXPECT().Close() 210 tr.Close() 211 }) 212 213 Context("handling packets", func() { 214 It("drops Initial packets with a too short connection ID", func() { 215 p := getPacket(&wire.Header{ 216 Type: protocol.PacketTypeInitial, 217 DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4}), 218 Version: serv.config.Versions[0], 219 }, nil) 220 tracer.EXPECT().DroppedPacket(p.remoteAddr, logging.PacketTypeInitial, p.Size(), logging.PacketDropUnexpectedPacket) 221 serv.handlePacket(p) 222 // make sure there are no Write calls on the packet conn 223 time.Sleep(50 * time.Millisecond) 224 }) 225 226 It("drops too small Initial", func() { 227 p := getPacket(&wire.Header{ 228 Type: protocol.PacketTypeInitial, 229 DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}), 230 Version: serv.config.Versions[0], 231 }, make([]byte, protocol.MinInitialPacketSize-100)) 232 tracer.EXPECT().DroppedPacket(p.remoteAddr, logging.PacketTypeInitial, p.Size(), logging.PacketDropUnexpectedPacket) 233 serv.handlePacket(p) 234 // make sure there are no Write calls on the packet conn 235 time.Sleep(50 * time.Millisecond) 236 }) 237 238 It("drops non-Initial packets", func() { 239 p := getPacket(&wire.Header{ 240 Type: protocol.PacketTypeHandshake, 241 Version: serv.config.Versions[0], 242 }, []byte("invalid")) 243 tracer.EXPECT().DroppedPacket(p.remoteAddr, logging.PacketTypeHandshake, p.Size(), logging.PacketDropUnexpectedPacket) 244 serv.handlePacket(p) 245 // make sure there are no Write calls on the packet conn 246 time.Sleep(50 * time.Millisecond) 247 }) 248 249 It("passes packets to existing connections", func() { 250 connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}) 251 p := getPacket(&wire.Header{ 252 Type: protocol.PacketTypeInitial, 253 DestConnectionID: connID, 254 Version: serv.config.Versions[0], 255 }, make([]byte, protocol.MinInitialPacketSize)) 256 conn := NewMockPacketHandler(mockCtrl) 257 phm.EXPECT().Get(connID).Return(conn, true) 258 handled := make(chan struct{}) 259 conn.EXPECT().handlePacket(p).Do(func(receivedPacket) { close(handled) }) 260 serv.handlePacket(p) 261 Eventually(handled).Should(BeClosed()) 262 }) 263 264 It("creates a connection when the token is accepted", func() { 265 serv.verifySourceAddress = func(net.Addr) bool { return true } 266 raddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337} 267 retryToken, err := serv.tokenGenerator.NewRetryToken( 268 raddr, 269 protocol.ParseConnectionID([]byte{0xde, 0xad, 0xc0, 0xde}), 270 protocol.ParseConnectionID([]byte{0xde, 0xca, 0xfb, 0xad}), 271 ) 272 Expect(err).ToNot(HaveOccurred()) 273 connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}) 274 hdr := &wire.Header{ 275 Type: protocol.PacketTypeInitial, 276 SrcConnectionID: protocol.ParseConnectionID([]byte{5, 4, 3, 2, 1}), 277 DestConnectionID: connID, 278 Version: protocol.Version1, 279 Token: retryToken, 280 } 281 p := getPacket(hdr, make([]byte, protocol.MinInitialPacketSize)) 282 p.remoteAddr = raddr 283 run := make(chan struct{}) 284 var token protocol.StatelessResetToken 285 rand.Read(token[:]) 286 287 var newConnID protocol.ConnectionID 288 conn := NewMockQUICConn(mockCtrl) 289 serv.newConn = func( 290 _ sendConn, 291 _ connRunner, 292 origDestConnID protocol.ConnectionID, 293 retrySrcConnID *protocol.ConnectionID, 294 clientDestConnID protocol.ConnectionID, 295 destConnID protocol.ConnectionID, 296 srcConnID protocol.ConnectionID, 297 _ ConnectionIDGenerator, 298 tokenP protocol.StatelessResetToken, 299 _ *Config, 300 _ *tls.Config, 301 _ *handshake.TokenGenerator, 302 _ bool, 303 _ *logging.ConnectionTracer, 304 _ ConnectionTracingID, 305 _ utils.Logger, 306 _ protocol.Version, 307 ) quicConn { 308 Expect(origDestConnID).To(Equal(protocol.ParseConnectionID([]byte{0xde, 0xad, 0xc0, 0xde}))) 309 Expect(*retrySrcConnID).To(Equal(protocol.ParseConnectionID([]byte{0xde, 0xca, 0xfb, 0xad}))) 310 Expect(clientDestConnID).To(Equal(hdr.DestConnectionID)) 311 Expect(destConnID).To(Equal(hdr.SrcConnectionID)) 312 // make sure we're using a server-generated connection ID 313 Expect(srcConnID).ToNot(Equal(hdr.DestConnectionID)) 314 Expect(srcConnID).ToNot(Equal(hdr.SrcConnectionID)) 315 newConnID = srcConnID 316 Expect(tokenP).To(Equal(token)) 317 conn.EXPECT().handlePacket(p) 318 conn.EXPECT().run().Do(func() error { close(run); return nil }) 319 conn.EXPECT().Context().Return(context.Background()) 320 conn.EXPECT().HandshakeComplete().Return(make(chan struct{})) 321 return conn 322 } 323 phm.EXPECT().Get(connID) 324 phm.EXPECT().GetStatelessResetToken(gomock.Any()).Return(token) 325 phm.EXPECT().AddWithConnID(connID, gomock.Any(), gomock.Any()).DoAndReturn(func(_, cid protocol.ConnectionID, h packetHandler) bool { 326 Expect(cid).To(Equal(newConnID)) 327 return true 328 }) 329 330 done := make(chan struct{}) 331 go func() { 332 defer GinkgoRecover() 333 serv.handlePacket(p) 334 // the Handshake packet is written by the connection. 335 // Make sure there are no Write calls on the packet conn. 336 time.Sleep(50 * time.Millisecond) 337 close(done) 338 }() 339 // make sure we're using a server-generated connection ID 340 Eventually(run).Should(BeClosed()) 341 Eventually(done).Should(BeClosed()) 342 // shutdown 343 conn.EXPECT().closeWithTransportError(gomock.Any()) 344 }) 345 346 It("sends a Version Negotiation Packet for unsupported versions", func() { 347 srcConnID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5}) 348 destConnID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6}) 349 packet := getPacket(&wire.Header{ 350 Type: protocol.PacketTypeHandshake, 351 SrcConnectionID: srcConnID, 352 DestConnectionID: destConnID, 353 Version: 0x42, 354 }, make([]byte, protocol.MinUnknownVersionPacketSize)) 355 raddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337} 356 packet.remoteAddr = raddr 357 tracer.EXPECT().SentVersionNegotiationPacket(packet.remoteAddr, gomock.Any(), gomock.Any(), gomock.Any()).Do(func(_ net.Addr, src, dest protocol.ArbitraryLenConnectionID, _ []protocol.Version) { 358 Expect(src).To(Equal(protocol.ArbitraryLenConnectionID(destConnID.Bytes()))) 359 Expect(dest).To(Equal(protocol.ArbitraryLenConnectionID(srcConnID.Bytes()))) 360 }) 361 done := make(chan struct{}) 362 conn.EXPECT().WriteTo(gomock.Any(), raddr).DoAndReturn(func(b []byte, _ net.Addr) (int, error) { 363 defer close(done) 364 Expect(wire.IsVersionNegotiationPacket(b)).To(BeTrue()) 365 dest, src, versions, err := wire.ParseVersionNegotiationPacket(b) 366 Expect(err).ToNot(HaveOccurred()) 367 Expect(dest).To(Equal(protocol.ArbitraryLenConnectionID(srcConnID.Bytes()))) 368 Expect(src).To(Equal(protocol.ArbitraryLenConnectionID(destConnID.Bytes()))) 369 Expect(versions).ToNot(ContainElement(protocol.Version(0x42))) 370 return len(b), nil 371 }) 372 serv.handlePacket(packet) 373 Eventually(done).Should(BeClosed()) 374 }) 375 376 It("doesn't send a Version Negotiation packets if sending them is disabled", func() { 377 serv.disableVersionNegotiation = true 378 srcConnID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5}) 379 destConnID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6}) 380 packet := getPacket(&wire.Header{ 381 Type: protocol.PacketTypeHandshake, 382 SrcConnectionID: srcConnID, 383 DestConnectionID: destConnID, 384 Version: 0x42, 385 }, make([]byte, protocol.MinUnknownVersionPacketSize)) 386 raddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337} 387 packet.remoteAddr = raddr 388 done := make(chan struct{}) 389 serv.handlePacket(packet) 390 Consistently(done, 50*time.Millisecond).ShouldNot(BeClosed()) 391 }) 392 393 It("ignores Version Negotiation packets", func() { 394 data := wire.ComposeVersionNegotiation( 395 protocol.ArbitraryLenConnectionID{1, 2, 3, 4}, 396 protocol.ArbitraryLenConnectionID{4, 3, 2, 1}, 397 []protocol.Version{1, 2, 3}, 398 ) 399 raddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337} 400 done := make(chan struct{}) 401 tracer.EXPECT().DroppedPacket(raddr, logging.PacketTypeVersionNegotiation, protocol.ByteCount(len(data)), logging.PacketDropUnexpectedPacket).Do(func(net.Addr, logging.PacketType, protocol.ByteCount, logging.PacketDropReason) { 402 close(done) 403 }) 404 serv.handlePacket(receivedPacket{ 405 remoteAddr: raddr, 406 data: data, 407 buffer: getPacketBuffer(), 408 }) 409 Eventually(done).Should(BeClosed()) 410 // make sure no other packet is sent 411 time.Sleep(scaleDuration(20 * time.Millisecond)) 412 }) 413 414 It("doesn't send a Version Negotiation Packet for unsupported versions, if the packet is too small", func() { 415 srcConnID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5}) 416 destConnID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6}) 417 p := getPacket(&wire.Header{ 418 Type: protocol.PacketTypeHandshake, 419 SrcConnectionID: srcConnID, 420 DestConnectionID: destConnID, 421 Version: 0x42, 422 }, make([]byte, protocol.MinUnknownVersionPacketSize-50)) 423 Expect(p.Size()).To(BeNumerically("<", protocol.MinUnknownVersionPacketSize)) 424 raddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337} 425 p.remoteAddr = raddr 426 done := make(chan struct{}) 427 tracer.EXPECT().DroppedPacket(raddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropUnexpectedPacket).Do(func(net.Addr, logging.PacketType, protocol.ByteCount, logging.PacketDropReason) { 428 close(done) 429 }) 430 serv.handlePacket(p) 431 Eventually(done).Should(BeClosed()) 432 // make sure no other packet is sent 433 time.Sleep(scaleDuration(20 * time.Millisecond)) 434 }) 435 436 It("replies with a Retry packet, if a token is required", func() { 437 connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}) 438 raddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337} 439 var called bool 440 serv.verifySourceAddress = func(addr net.Addr) bool { 441 Expect(addr).To(Equal(raddr)) 442 called = true 443 return true 444 } 445 hdr := &wire.Header{ 446 Type: protocol.PacketTypeInitial, 447 SrcConnectionID: protocol.ParseConnectionID([]byte{5, 4, 3, 2, 1}), 448 DestConnectionID: connID, 449 Version: protocol.Version1, 450 } 451 packet := getPacket(hdr, make([]byte, protocol.MinInitialPacketSize)) 452 packet.remoteAddr = raddr 453 tracer.EXPECT().SentPacket(packet.remoteAddr, gomock.Any(), gomock.Any(), nil).Do(func(_ net.Addr, replyHdr *logging.Header, _ logging.ByteCount, _ []logging.Frame) { 454 Expect(replyHdr.Type).To(Equal(protocol.PacketTypeRetry)) 455 Expect(replyHdr.SrcConnectionID).ToNot(Equal(hdr.DestConnectionID)) 456 Expect(replyHdr.DestConnectionID).To(Equal(hdr.SrcConnectionID)) 457 Expect(replyHdr.Token).ToNot(BeEmpty()) 458 }) 459 done := make(chan struct{}) 460 conn.EXPECT().WriteTo(gomock.Any(), raddr).DoAndReturn(func(b []byte, _ net.Addr) (int, error) { 461 defer close(done) 462 replyHdr := parseHeader(b) 463 Expect(replyHdr.Type).To(Equal(protocol.PacketTypeRetry)) 464 Expect(replyHdr.SrcConnectionID).ToNot(Equal(hdr.DestConnectionID)) 465 Expect(replyHdr.DestConnectionID).To(Equal(hdr.SrcConnectionID)) 466 Expect(replyHdr.Token).ToNot(BeEmpty()) 467 Expect(b[len(b)-16:]).To(Equal(handshake.GetRetryIntegrityTag(b[:len(b)-16], hdr.DestConnectionID, hdr.Version)[:])) 468 return len(b), nil 469 }) 470 phm.EXPECT().Get(connID) 471 serv.handlePacket(packet) 472 Eventually(done).Should(BeClosed()) 473 Expect(called).To(BeTrue()) 474 }) 475 476 It("creates a connection, if no token is required", func() { 477 connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}) 478 hdr := &wire.Header{ 479 Type: protocol.PacketTypeInitial, 480 SrcConnectionID: protocol.ParseConnectionID([]byte{5, 4, 3, 2, 1}), 481 DestConnectionID: connID, 482 Version: protocol.Version1, 483 } 484 p := getPacket(hdr, make([]byte, protocol.MinInitialPacketSize)) 485 run := make(chan struct{}) 486 var token protocol.StatelessResetToken 487 rand.Read(token[:]) 488 489 var newConnID protocol.ConnectionID 490 conn := NewMockQUICConn(mockCtrl) 491 serv.newConn = func( 492 _ sendConn, 493 _ connRunner, 494 origDestConnID protocol.ConnectionID, 495 retrySrcConnID *protocol.ConnectionID, 496 clientDestConnID protocol.ConnectionID, 497 destConnID protocol.ConnectionID, 498 srcConnID protocol.ConnectionID, 499 _ ConnectionIDGenerator, 500 tokenP protocol.StatelessResetToken, 501 _ *Config, 502 _ *tls.Config, 503 _ *handshake.TokenGenerator, 504 _ bool, 505 _ *logging.ConnectionTracer, 506 _ ConnectionTracingID, 507 _ utils.Logger, 508 _ protocol.Version, 509 ) quicConn { 510 Expect(origDestConnID).To(Equal(hdr.DestConnectionID)) 511 Expect(retrySrcConnID).To(BeNil()) 512 Expect(clientDestConnID).To(Equal(hdr.DestConnectionID)) 513 Expect(destConnID).To(Equal(hdr.SrcConnectionID)) 514 // make sure we're using a server-generated connection ID 515 Expect(srcConnID).ToNot(Equal(hdr.DestConnectionID)) 516 Expect(srcConnID).ToNot(Equal(hdr.SrcConnectionID)) 517 newConnID = srcConnID 518 Expect(tokenP).To(Equal(token)) 519 conn.EXPECT().handlePacket(p) 520 conn.EXPECT().run().Do(func() error { close(run); return nil }) 521 conn.EXPECT().Context().Return(context.Background()) 522 conn.EXPECT().HandshakeComplete().Return(make(chan struct{})) 523 return conn 524 } 525 gomock.InOrder( 526 phm.EXPECT().Get(connID), 527 phm.EXPECT().GetStatelessResetToken(gomock.Any()).Return(token), 528 phm.EXPECT().AddWithConnID(connID, gomock.Any(), gomock.Any()).DoAndReturn(func(_, c protocol.ConnectionID, h packetHandler) bool { 529 Expect(c).To(Equal(newConnID)) 530 return true 531 }), 532 ) 533 534 done := make(chan struct{}) 535 go func() { 536 defer GinkgoRecover() 537 serv.handlePacket(p) 538 // the Handshake packet is written by the connection 539 // make sure there are no Write calls on the packet conn 540 time.Sleep(50 * time.Millisecond) 541 close(done) 542 }() 543 // make sure we're using a server-generated connection ID 544 Eventually(run).Should(BeClosed()) 545 Eventually(done).Should(BeClosed()) 546 // shutdown 547 conn.EXPECT().closeWithTransportError(gomock.Any()).MaxTimes(1) 548 }) 549 550 It("drops packets if the receive queue is full", func() { 551 serv.verifySourceAddress = func(net.Addr) bool { return false } 552 553 phm.EXPECT().Get(gomock.Any()).AnyTimes() 554 phm.EXPECT().GetStatelessResetToken(gomock.Any()).AnyTimes() 555 phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).Return(true).AnyTimes() 556 557 acceptConn := make(chan struct{}) 558 var counter atomic.Uint32 559 serv.newConn = func( 560 _ sendConn, 561 runner connRunner, 562 _ protocol.ConnectionID, 563 _ *protocol.ConnectionID, 564 _ protocol.ConnectionID, 565 _ protocol.ConnectionID, 566 _ protocol.ConnectionID, 567 _ ConnectionIDGenerator, 568 _ protocol.StatelessResetToken, 569 _ *Config, 570 _ *tls.Config, 571 _ *handshake.TokenGenerator, 572 _ bool, 573 _ *logging.ConnectionTracer, 574 _ ConnectionTracingID, 575 _ utils.Logger, 576 _ protocol.Version, 577 ) quicConn { 578 <-acceptConn 579 counter.Add(1) 580 conn := NewMockQUICConn(mockCtrl) 581 conn.EXPECT().handlePacket(gomock.Any()).MaxTimes(1) 582 conn.EXPECT().run().MaxTimes(1) 583 conn.EXPECT().Context().Return(context.Background()).MaxTimes(1) 584 conn.EXPECT().HandshakeComplete().Return(make(chan struct{})).MaxTimes(1) 585 // shutdown 586 conn.EXPECT().closeWithTransportError(gomock.Any()).MaxTimes(1) 587 return conn 588 } 589 590 p := getInitial(protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8})) 591 serv.handlePacket(p) 592 tracer.EXPECT().DroppedPacket(p.remoteAddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropDOSPrevention).MinTimes(1) 593 var wg sync.WaitGroup 594 for i := 0; i < 3*protocol.MaxServerUnprocessedPackets; i++ { 595 wg.Add(1) 596 go func() { 597 defer GinkgoRecover() 598 defer wg.Done() 599 serv.handlePacket(getInitial(protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}))) 600 }() 601 } 602 wg.Wait() 603 604 close(acceptConn) 605 Eventually( 606 func() uint32 { return counter.Load() }, 607 scaleDuration(100*time.Millisecond), 608 ).Should(BeEquivalentTo(protocol.MaxServerUnprocessedPackets + 1)) 609 Consistently(func() uint32 { return counter.Load() }).Should(BeEquivalentTo(protocol.MaxServerUnprocessedPackets + 1)) 610 }) 611 612 It("only creates a single connection for a duplicate Initial", func() { 613 done := make(chan struct{}) 614 serv.newConn = func( 615 _ sendConn, 616 runner connRunner, 617 _ protocol.ConnectionID, 618 _ *protocol.ConnectionID, 619 _ protocol.ConnectionID, 620 _ protocol.ConnectionID, 621 _ protocol.ConnectionID, 622 _ ConnectionIDGenerator, 623 _ protocol.StatelessResetToken, 624 _ *Config, 625 _ *tls.Config, 626 _ *handshake.TokenGenerator, 627 _ bool, 628 _ *logging.ConnectionTracer, 629 _ ConnectionTracingID, 630 _ utils.Logger, 631 _ protocol.Version, 632 ) quicConn { 633 conn := NewMockQUICConn(mockCtrl) 634 conn.EXPECT().handlePacket(gomock.Any()) 635 conn.EXPECT().closeWithTransportError(qerr.ConnectionRefused).Do(func(qerr.TransportErrorCode) { 636 close(done) 637 }) 638 return conn 639 } 640 641 connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9}) 642 p := getInitial(connID) 643 phm.EXPECT().Get(connID) 644 phm.EXPECT().GetStatelessResetToken(gomock.Any()) 645 phm.EXPECT().AddWithConnID(connID, gomock.Any(), gomock.Any()).Return(false) // connection ID collision 646 Expect(serv.handlePacketImpl(p)).To(BeTrue()) 647 Eventually(done).Should(BeClosed()) 648 }) 649 650 It("limits the number of unvalidated handshakes", func() { 651 const limit = 3 652 limiter := rate.NewLimiter(0, limit) 653 serv.verifySourceAddress = func(net.Addr) bool { return !limiter.Allow() } 654 655 phm.EXPECT().Get(gomock.Any()).AnyTimes() 656 phm.EXPECT().GetStatelessResetToken(gomock.Any()).AnyTimes() 657 phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).Return(true).AnyTimes() 658 659 connChan := make(chan *MockQUICConn, 1) 660 var wg sync.WaitGroup 661 wg.Add(limit) 662 done := make(chan struct{}) 663 serv.newConn = func( 664 _ sendConn, 665 runner connRunner, 666 _ protocol.ConnectionID, 667 _ *protocol.ConnectionID, 668 _ protocol.ConnectionID, 669 _ protocol.ConnectionID, 670 _ protocol.ConnectionID, 671 _ ConnectionIDGenerator, 672 _ protocol.StatelessResetToken, 673 _ *Config, 674 _ *tls.Config, 675 _ *handshake.TokenGenerator, 676 _ bool, 677 _ *logging.ConnectionTracer, 678 _ ConnectionTracingID, 679 _ utils.Logger, 680 _ protocol.Version, 681 ) quicConn { 682 conn := <-connChan 683 conn.EXPECT().handlePacket(gomock.Any()) 684 conn.EXPECT().run() 685 conn.EXPECT().Context().Return(context.Background()) 686 conn.EXPECT().HandshakeComplete().DoAndReturn(func() <-chan struct{} { wg.Done(); return done }) 687 return conn 688 } 689 690 // Initiate the maximum number of allowed connection attempts. 691 for i := 0; i < limit; i++ { 692 conn := NewMockQUICConn(mockCtrl) 693 connChan <- conn 694 serv.handlePacket(getInitialWithRandomDestConnID()) 695 } 696 697 // Now initiate another connection attempt. 698 p := getInitialWithRandomDestConnID() 699 tracer.EXPECT().SentPacket(p.remoteAddr, gomock.Any(), gomock.Any(), gomock.Any()).Do(func(_ net.Addr, replyHdr *logging.Header, _ logging.ByteCount, frames []logging.Frame) { 700 defer GinkgoRecover() 701 Expect(replyHdr.Type).To(Equal(protocol.PacketTypeRetry)) 702 }) 703 conn.EXPECT().WriteTo(gomock.Any(), gomock.Any()).DoAndReturn(func(b []byte, _ net.Addr) (int, error) { 704 defer GinkgoRecover() 705 defer close(done) 706 hdr, _, _, err := wire.ParsePacket(b) 707 Expect(err).ToNot(HaveOccurred()) 708 Expect(hdr.Type).To(Equal(protocol.PacketTypeRetry)) 709 return len(b), nil 710 }) 711 serv.handlePacket(p) 712 Eventually(done).Should(BeClosed()) 713 714 for i := 0; i < limit; i++ { 715 _, err := serv.Accept(context.Background()) 716 Expect(err).ToNot(HaveOccurred()) 717 } 718 wg.Wait() 719 }) 720 }) 721 722 Context("token validation", func() { 723 It("decodes the token from the token field", func() { 724 serv.newConn = func( 725 _ sendConn, 726 _ connRunner, 727 _ protocol.ConnectionID, 728 _ *protocol.ConnectionID, 729 _ protocol.ConnectionID, 730 _ protocol.ConnectionID, 731 _ protocol.ConnectionID, 732 _ ConnectionIDGenerator, 733 _ protocol.StatelessResetToken, 734 _ *Config, 735 _ *tls.Config, 736 _ *handshake.TokenGenerator, 737 _ bool, 738 _ *logging.ConnectionTracer, 739 _ ConnectionTracingID, 740 _ utils.Logger, 741 _ protocol.Version, 742 ) quicConn { 743 c := NewMockQUICConn(mockCtrl) 744 c.EXPECT().handlePacket(gomock.Any()) 745 c.EXPECT().run() 746 c.EXPECT().HandshakeComplete() 747 ctx, cancel := context.WithCancel(context.Background()) 748 cancel() 749 c.EXPECT().Context().Return(ctx) 750 return c 751 } 752 raddr := &net.UDPAddr{IP: net.IPv4(192, 168, 13, 37), Port: 1337} 753 token, err := serv.tokenGenerator.NewRetryToken(raddr, protocol.ConnectionID{}, protocol.ConnectionID{}) 754 Expect(err).ToNot(HaveOccurred()) 755 packet := getPacket(&wire.Header{ 756 Type: protocol.PacketTypeInitial, 757 Token: token, 758 Version: serv.config.Versions[0], 759 }, make([]byte, protocol.MinInitialPacketSize)) 760 packet.remoteAddr = raddr 761 conn.EXPECT().WriteTo(gomock.Any(), gomock.Any()).MaxTimes(1) 762 tracer.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).MaxTimes(1) 763 764 done := make(chan struct{}) 765 phm.EXPECT().Get(gomock.Any()) 766 phm.EXPECT().GetStatelessResetToken(gomock.Any()) 767 phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, _ packetHandler) bool { 768 close(done) 769 return true 770 }) 771 phm.EXPECT().Remove(gomock.Any()).AnyTimes() 772 serv.handlePacket(packet) 773 Eventually(done).Should(BeClosed()) 774 }) 775 776 It("sends an INVALID_TOKEN error, if an invalid retry token is received", func() { 777 serv.verifySourceAddress = func(net.Addr) bool { return true } 778 token, err := serv.tokenGenerator.NewRetryToken(&net.UDPAddr{}, protocol.ConnectionID{}, protocol.ConnectionID{}) 779 Expect(err).ToNot(HaveOccurred()) 780 hdr := &wire.Header{ 781 Type: protocol.PacketTypeInitial, 782 SrcConnectionID: protocol.ParseConnectionID([]byte{5, 4, 3, 2, 1}), 783 DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}), 784 Token: token, 785 Version: protocol.Version1, 786 } 787 packet := getPacket(hdr, make([]byte, protocol.MinInitialPacketSize)) 788 packet.data = append(packet.data, []byte("coalesced packet")...) // add some garbage to simulate a coalesced packet 789 raddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337} 790 packet.remoteAddr = raddr 791 tracer.EXPECT().SentPacket(packet.remoteAddr, gomock.Any(), gomock.Any(), gomock.Any()).Do(func(_ net.Addr, replyHdr *logging.Header, _ logging.ByteCount, frames []logging.Frame) { 792 Expect(replyHdr.Type).To(Equal(protocol.PacketTypeInitial)) 793 Expect(replyHdr.SrcConnectionID).To(Equal(hdr.DestConnectionID)) 794 Expect(replyHdr.DestConnectionID).To(Equal(hdr.SrcConnectionID)) 795 Expect(frames).To(HaveLen(1)) 796 Expect(frames[0]).To(BeAssignableToTypeOf(&wire.ConnectionCloseFrame{})) 797 ccf := frames[0].(*logging.ConnectionCloseFrame) 798 Expect(ccf.IsApplicationError).To(BeFalse()) 799 Expect(ccf.ErrorCode).To(BeEquivalentTo(qerr.InvalidToken)) 800 }) 801 done := make(chan struct{}) 802 conn.EXPECT().WriteTo(gomock.Any(), raddr).DoAndReturn(func(b []byte, _ net.Addr) (int, error) { 803 defer close(done) 804 checkConnectionCloseError(b, hdr, qerr.InvalidToken) 805 return len(b), nil 806 }) 807 phm.EXPECT().Get(gomock.Any()) 808 serv.handlePacket(packet) 809 Eventually(done).Should(BeClosed()) 810 }) 811 812 It("sends an INVALID_TOKEN error, if an expired retry token is received", func() { 813 serv.verifySourceAddress = func(net.Addr) bool { return true } 814 serv.config.HandshakeIdleTimeout = time.Millisecond / 2 // the maximum retry token age is equivalent to the handshake timeout 815 Expect(serv.config.maxRetryTokenAge()).To(Equal(time.Millisecond)) 816 raddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337} 817 token, err := serv.tokenGenerator.NewRetryToken(raddr, protocol.ConnectionID{}, protocol.ConnectionID{}) 818 Expect(err).ToNot(HaveOccurred()) 819 time.Sleep(2 * time.Millisecond) // make sure the token is expired 820 hdr := &wire.Header{ 821 Type: protocol.PacketTypeInitial, 822 SrcConnectionID: protocol.ParseConnectionID([]byte{5, 4, 3, 2, 1}), 823 DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}), 824 Token: token, 825 Version: protocol.Version1, 826 } 827 packet := getPacket(hdr, make([]byte, protocol.MinInitialPacketSize)) 828 packet.remoteAddr = raddr 829 tracer.EXPECT().SentPacket(packet.remoteAddr, gomock.Any(), gomock.Any(), gomock.Any()).Do(func(_ net.Addr, replyHdr *logging.Header, _ logging.ByteCount, frames []logging.Frame) { 830 Expect(replyHdr.Type).To(Equal(protocol.PacketTypeInitial)) 831 Expect(replyHdr.SrcConnectionID).To(Equal(hdr.DestConnectionID)) 832 Expect(replyHdr.DestConnectionID).To(Equal(hdr.SrcConnectionID)) 833 Expect(frames).To(HaveLen(1)) 834 Expect(frames[0]).To(BeAssignableToTypeOf(&wire.ConnectionCloseFrame{})) 835 ccf := frames[0].(*logging.ConnectionCloseFrame) 836 Expect(ccf.IsApplicationError).To(BeFalse()) 837 Expect(ccf.ErrorCode).To(BeEquivalentTo(qerr.InvalidToken)) 838 }) 839 done := make(chan struct{}) 840 conn.EXPECT().WriteTo(gomock.Any(), raddr).DoAndReturn(func(b []byte, _ net.Addr) (int, error) { 841 defer close(done) 842 checkConnectionCloseError(b, hdr, qerr.InvalidToken) 843 return len(b), nil 844 }) 845 phm.EXPECT().Get(gomock.Any()) 846 serv.handlePacket(packet) 847 Eventually(done).Should(BeClosed()) 848 }) 849 850 It("doesn't send an INVALID_TOKEN error, if an invalid non-retry token is received", func() { 851 serv.verifySourceAddress = func(net.Addr) bool { return true } 852 token, err := serv.tokenGenerator.NewToken(&net.UDPAddr{IP: net.IPv4(192, 168, 0, 1), Port: 1337}) 853 Expect(err).ToNot(HaveOccurred()) 854 hdr := &wire.Header{ 855 Type: protocol.PacketTypeInitial, 856 SrcConnectionID: protocol.ParseConnectionID([]byte{5, 4, 3, 2, 1}), 857 DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}), 858 Token: token, 859 Version: protocol.Version1, 860 } 861 packet := getPacket(hdr, make([]byte, protocol.MinInitialPacketSize)) 862 packet.data[len(packet.data)-10] ^= 0xff // corrupt the packet 863 raddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337} 864 packet.remoteAddr = raddr 865 tracer.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).MaxTimes(1) 866 done := make(chan struct{}) 867 conn.EXPECT().WriteTo(gomock.Any(), raddr).DoAndReturn(func(b []byte, _ net.Addr) (int, error) { 868 defer close(done) 869 replyHdr := parseHeader(b) 870 Expect(replyHdr.Type).To(Equal(protocol.PacketTypeRetry)) 871 return len(b), nil 872 }) 873 phm.EXPECT().Get(gomock.Any()) 874 serv.handlePacket(packet) 875 // make sure there are no Write calls on the packet conn 876 Eventually(done).Should(BeClosed()) 877 }) 878 879 It("sends an INVALID_TOKEN error, if an expired non-retry token is received", func() { 880 serv.verifySourceAddress = func(net.Addr) bool { return true } 881 serv.maxTokenAge = time.Millisecond 882 raddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337} 883 token, err := serv.tokenGenerator.NewToken(raddr) 884 Expect(err).ToNot(HaveOccurred()) 885 time.Sleep(2 * time.Millisecond) // make sure the token is expired 886 hdr := &wire.Header{ 887 Type: protocol.PacketTypeInitial, 888 SrcConnectionID: protocol.ParseConnectionID([]byte{5, 4, 3, 2, 1}), 889 DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}), 890 Token: token, 891 Version: protocol.Version1, 892 } 893 packet := getPacket(hdr, make([]byte, protocol.MinInitialPacketSize)) 894 packet.remoteAddr = raddr 895 tracer.EXPECT().SentPacket(packet.remoteAddr, gomock.Any(), gomock.Any(), gomock.Any()).Do(func(_ net.Addr, replyHdr *logging.Header, _ logging.ByteCount, frames []logging.Frame) { 896 Expect(replyHdr.Type).To(Equal(protocol.PacketTypeRetry)) 897 }) 898 done := make(chan struct{}) 899 conn.EXPECT().WriteTo(gomock.Any(), raddr).DoAndReturn(func(b []byte, _ net.Addr) (int, error) { 900 defer close(done) 901 return len(b), nil 902 }) 903 phm.EXPECT().Get(gomock.Any()) 904 serv.handlePacket(packet) 905 Eventually(done).Should(BeClosed()) 906 }) 907 908 It("doesn't send an INVALID_TOKEN error, if the packet is corrupted", func() { 909 token, err := serv.tokenGenerator.NewRetryToken(&net.UDPAddr{}, protocol.ConnectionID{}, protocol.ConnectionID{}) 910 Expect(err).ToNot(HaveOccurred()) 911 hdr := &wire.Header{ 912 Type: protocol.PacketTypeInitial, 913 SrcConnectionID: protocol.ParseConnectionID([]byte{5, 4, 3, 2, 1}), 914 DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}), 915 Token: token, 916 Version: protocol.Version1, 917 } 918 packet := getPacket(hdr, make([]byte, protocol.MinInitialPacketSize)) 919 packet.data[len(packet.data)-10] ^= 0xff // corrupt the packet 920 packet.remoteAddr = &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337} 921 done := make(chan struct{}) 922 tracer.EXPECT().DroppedPacket(packet.remoteAddr, logging.PacketTypeInitial, packet.Size(), logging.PacketDropPayloadDecryptError).Do(func(net.Addr, logging.PacketType, protocol.ByteCount, logging.PacketDropReason) { close(done) }) 923 phm.EXPECT().Get(gomock.Any()) 924 serv.handlePacket(packet) 925 // make sure there are no Write calls on the packet conn 926 time.Sleep(50 * time.Millisecond) 927 Eventually(done).Should(BeClosed()) 928 }) 929 }) 930 931 Context("accepting connections", func() { 932 It("returns Accept when closed", func() { 933 done := make(chan struct{}) 934 go func() { 935 defer GinkgoRecover() 936 _, err := serv.Accept(context.Background()) 937 Expect(err).To(MatchError(ErrServerClosed)) 938 close(done) 939 }() 940 941 serv.Close() 942 Eventually(done).Should(BeClosed()) 943 }) 944 945 It("returns immediately, if an error occurred before", func() { 946 serv.Close() 947 for i := 0; i < 3; i++ { 948 _, err := serv.Accept(context.Background()) 949 Expect(err).To(MatchError(ErrServerClosed)) 950 } 951 }) 952 953 PIt("closes connection that are still handshaking after Close", func() { 954 serv.Close() 955 956 destroyed := make(chan struct{}) 957 serv.newConn = func( 958 _ sendConn, 959 _ connRunner, 960 _ protocol.ConnectionID, 961 _ *protocol.ConnectionID, 962 _ protocol.ConnectionID, 963 _ protocol.ConnectionID, 964 _ protocol.ConnectionID, 965 _ ConnectionIDGenerator, 966 _ protocol.StatelessResetToken, 967 conf *Config, 968 _ *tls.Config, 969 _ *handshake.TokenGenerator, 970 _ bool, 971 _ *logging.ConnectionTracer, 972 _ ConnectionTracingID, 973 _ utils.Logger, 974 _ protocol.Version, 975 ) quicConn { 976 conn := NewMockQUICConn(mockCtrl) 977 conn.EXPECT().handlePacket(gomock.Any()) 978 conn.EXPECT().closeWithTransportError(ConnectionRefused).Do(func(TransportErrorCode) { close(destroyed) }) 979 conn.EXPECT().HandshakeComplete().Return(make(chan struct{})) 980 conn.EXPECT().run().MaxTimes(1) 981 conn.EXPECT().Context().Return(context.Background()) 982 return conn 983 } 984 phm.EXPECT().Get(gomock.Any()) 985 phm.EXPECT().GetStatelessResetToken(gomock.Any()) 986 phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).Return(true) 987 serv.handleInitialImpl( 988 receivedPacket{buffer: getPacketBuffer()}, 989 &wire.Header{DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8})}, 990 ) 991 Eventually(destroyed).Should(BeClosed()) 992 }) 993 994 It("returns when the context is canceled", func() { 995 ctx, cancel := context.WithCancel(context.Background()) 996 done := make(chan struct{}) 997 go func() { 998 defer GinkgoRecover() 999 _, err := serv.Accept(ctx) 1000 Expect(err).To(MatchError("context canceled")) 1001 close(done) 1002 }() 1003 1004 Consistently(done).ShouldNot(BeClosed()) 1005 cancel() 1006 Eventually(done).Should(BeClosed()) 1007 }) 1008 1009 It("uses the config returned by GetConfigClient", func() { 1010 conn := NewMockQUICConn(mockCtrl) 1011 1012 conf := &Config{MaxIncomingStreams: 1234} 1013 serv.config = populateConfig(&Config{GetConfigForClient: func(*ClientHelloInfo) (*Config, error) { return conf, nil }}) 1014 done := make(chan struct{}) 1015 go func() { 1016 defer GinkgoRecover() 1017 s, err := serv.Accept(context.Background()) 1018 Expect(err).ToNot(HaveOccurred()) 1019 Expect(s).To(Equal(conn)) 1020 close(done) 1021 }() 1022 1023 handshakeChan := make(chan struct{}) 1024 serv.newConn = func( 1025 _ sendConn, 1026 _ connRunner, 1027 _ protocol.ConnectionID, 1028 _ *protocol.ConnectionID, 1029 _ protocol.ConnectionID, 1030 _ protocol.ConnectionID, 1031 _ protocol.ConnectionID, 1032 _ ConnectionIDGenerator, 1033 _ protocol.StatelessResetToken, 1034 conf *Config, 1035 _ *tls.Config, 1036 _ *handshake.TokenGenerator, 1037 _ bool, 1038 _ *logging.ConnectionTracer, 1039 _ ConnectionTracingID, 1040 _ utils.Logger, 1041 _ protocol.Version, 1042 ) quicConn { 1043 Expect(conf.MaxIncomingStreams).To(BeEquivalentTo(1234)) 1044 conn.EXPECT().handlePacket(gomock.Any()) 1045 conn.EXPECT().HandshakeComplete().Return(handshakeChan) 1046 conn.EXPECT().run() 1047 conn.EXPECT().Context().Return(context.Background()) 1048 return conn 1049 } 1050 phm.EXPECT().Get(gomock.Any()) 1051 phm.EXPECT().GetStatelessResetToken(gomock.Any()) 1052 phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).Return(true) 1053 serv.handleInitialImpl( 1054 receivedPacket{buffer: getPacketBuffer()}, 1055 &wire.Header{DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8})}, 1056 ) 1057 Consistently(done).ShouldNot(BeClosed()) 1058 close(handshakeChan) // complete the handshake 1059 Eventually(done).Should(BeClosed()) 1060 }) 1061 1062 It("rejects a connection attempt when GetConfigClient returns an error", func() { 1063 serv.config = populateConfig(&Config{GetConfigForClient: func(*ClientHelloInfo) (*Config, error) { return nil, errors.New("rejected") }}) 1064 1065 phm.EXPECT().Get(gomock.Any()) 1066 done := make(chan struct{}) 1067 tracer.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) 1068 conn.EXPECT().WriteTo(gomock.Any(), gomock.Any()).DoAndReturn(func(b []byte, _ net.Addr) (int, error) { 1069 defer close(done) 1070 rejectHdr := parseHeader(b) 1071 Expect(rejectHdr.Type).To(Equal(protocol.PacketTypeInitial)) 1072 return len(b), nil 1073 }) 1074 serv.handleInitialImpl( 1075 receivedPacket{buffer: getPacketBuffer()}, 1076 &wire.Header{DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}), Version: protocol.Version1}, 1077 ) 1078 Eventually(done).Should(BeClosed()) 1079 }) 1080 1081 It("accepts new connections when the handshake completes", func() { 1082 conn := NewMockQUICConn(mockCtrl) 1083 1084 done := make(chan struct{}) 1085 go func() { 1086 defer GinkgoRecover() 1087 s, err := serv.Accept(context.Background()) 1088 Expect(err).ToNot(HaveOccurred()) 1089 Expect(s).To(Equal(conn)) 1090 close(done) 1091 }() 1092 1093 handshakeChan := make(chan struct{}) 1094 serv.newConn = func( 1095 _ sendConn, 1096 runner connRunner, 1097 _ protocol.ConnectionID, 1098 _ *protocol.ConnectionID, 1099 _ protocol.ConnectionID, 1100 _ protocol.ConnectionID, 1101 _ protocol.ConnectionID, 1102 _ ConnectionIDGenerator, 1103 _ protocol.StatelessResetToken, 1104 _ *Config, 1105 _ *tls.Config, 1106 _ *handshake.TokenGenerator, 1107 _ bool, 1108 _ *logging.ConnectionTracer, 1109 _ ConnectionTracingID, 1110 _ utils.Logger, 1111 _ protocol.Version, 1112 ) quicConn { 1113 conn.EXPECT().handlePacket(gomock.Any()) 1114 conn.EXPECT().HandshakeComplete().Return(handshakeChan) 1115 conn.EXPECT().run() 1116 conn.EXPECT().Context().Return(context.Background()) 1117 return conn 1118 } 1119 phm.EXPECT().Get(gomock.Any()) 1120 phm.EXPECT().GetStatelessResetToken(gomock.Any()) 1121 phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).Return(true) 1122 serv.handleInitialImpl( 1123 receivedPacket{buffer: getPacketBuffer()}, 1124 &wire.Header{DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8})}, 1125 ) 1126 Consistently(done).ShouldNot(BeClosed()) 1127 close(handshakeChan) // complete the handshake 1128 Eventually(done).Should(BeClosed()) 1129 }) 1130 }) 1131 }) 1132 1133 Context("server accepting connections that haven't completed the handshake", func() { 1134 var ( 1135 serv *EarlyListener 1136 phm *MockPacketHandlerManager 1137 ) 1138 1139 BeforeEach(func() { 1140 var err error 1141 serv, err = ListenEarly(conn, tlsConf, nil) 1142 Expect(err).ToNot(HaveOccurred()) 1143 phm = NewMockPacketHandlerManager(mockCtrl) 1144 serv.baseServer.connHandler = phm 1145 }) 1146 1147 AfterEach(func() { 1148 serv.Close() 1149 }) 1150 1151 It("accepts new connections when they become ready", func() { 1152 conn := NewMockQUICConn(mockCtrl) 1153 1154 done := make(chan struct{}) 1155 go func() { 1156 defer GinkgoRecover() 1157 s, err := serv.Accept(context.Background()) 1158 Expect(err).ToNot(HaveOccurred()) 1159 Expect(s).To(Equal(conn)) 1160 close(done) 1161 }() 1162 1163 ready := make(chan struct{}) 1164 serv.baseServer.newConn = func( 1165 _ sendConn, 1166 runner connRunner, 1167 _ protocol.ConnectionID, 1168 _ *protocol.ConnectionID, 1169 _ protocol.ConnectionID, 1170 _ protocol.ConnectionID, 1171 _ protocol.ConnectionID, 1172 _ ConnectionIDGenerator, 1173 _ protocol.StatelessResetToken, 1174 _ *Config, 1175 _ *tls.Config, 1176 _ *handshake.TokenGenerator, 1177 _ bool, 1178 _ *logging.ConnectionTracer, 1179 _ ConnectionTracingID, 1180 _ utils.Logger, 1181 _ protocol.Version, 1182 ) quicConn { 1183 conn.EXPECT().handlePacket(gomock.Any()) 1184 conn.EXPECT().run() 1185 conn.EXPECT().earlyConnReady().Return(ready) 1186 conn.EXPECT().Context().Return(context.Background()) 1187 return conn 1188 } 1189 phm.EXPECT().Get(gomock.Any()) 1190 phm.EXPECT().GetStatelessResetToken(gomock.Any()) 1191 phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).Return(true) 1192 serv.baseServer.handleInitialImpl( 1193 receivedPacket{buffer: getPacketBuffer()}, 1194 &wire.Header{DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8})}, 1195 ) 1196 Consistently(done).ShouldNot(BeClosed()) 1197 close(ready) 1198 Eventually(done).Should(BeClosed()) 1199 }) 1200 1201 It("rejects new connection attempts if the accept queue is full", func() { 1202 connChan := make(chan *MockQUICConn, 1) 1203 var wg sync.WaitGroup // to make sure the test fully completes 1204 wg.Add(protocol.MaxAcceptQueueSize) 1205 serv.baseServer.newConn = func( 1206 _ sendConn, 1207 runner connRunner, 1208 _ protocol.ConnectionID, 1209 _ *protocol.ConnectionID, 1210 _ protocol.ConnectionID, 1211 _ protocol.ConnectionID, 1212 _ protocol.ConnectionID, 1213 _ ConnectionIDGenerator, 1214 _ protocol.StatelessResetToken, 1215 _ *Config, 1216 _ *tls.Config, 1217 _ *handshake.TokenGenerator, 1218 _ bool, 1219 _ *logging.ConnectionTracer, 1220 _ ConnectionTracingID, 1221 _ utils.Logger, 1222 _ protocol.Version, 1223 ) quicConn { 1224 ready := make(chan struct{}) 1225 close(ready) 1226 conn := <-connChan 1227 conn.EXPECT().handlePacket(gomock.Any()) 1228 conn.EXPECT().run().Do(func() error { wg.Done(); return nil }) 1229 conn.EXPECT().earlyConnReady().Return(ready) 1230 conn.EXPECT().Context().Return(context.Background()) 1231 return conn 1232 } 1233 1234 phm.EXPECT().Get(gomock.Any()).AnyTimes() 1235 phm.EXPECT().GetStatelessResetToken(gomock.Any()).Times(protocol.MaxAcceptQueueSize) 1236 phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).Return(true).Times(protocol.MaxAcceptQueueSize) 1237 for i := 0; i < protocol.MaxAcceptQueueSize; i++ { 1238 conn := NewMockQUICConn(mockCtrl) 1239 connChan <- conn 1240 serv.baseServer.handlePacket(getInitialWithRandomDestConnID()) 1241 } 1242 1243 Eventually(serv.baseServer.connQueue).Should(HaveLen(protocol.MaxAcceptQueueSize)) 1244 wg.Wait() 1245 wg.Add(1) 1246 1247 rejected := make(chan struct{}) 1248 phm.EXPECT().GetStatelessResetToken(gomock.Any()) 1249 phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).Return(true) 1250 conn := NewMockQUICConn(mockCtrl) 1251 conn.EXPECT().closeWithTransportError(ConnectionRefused).Do(func(qerr.TransportErrorCode) { 1252 close(rejected) 1253 }) 1254 connChan <- conn 1255 serv.baseServer.handlePacket(getInitialWithRandomDestConnID()) 1256 Eventually(rejected).Should(BeClosed()) 1257 }) 1258 1259 It("doesn't accept new connections if they were closed in the mean time", func() { 1260 p := getInitial(protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10})) 1261 ctx, cancel := context.WithCancel(context.Background()) 1262 connCreated := make(chan struct{}) 1263 conn := NewMockQUICConn(mockCtrl) 1264 serv.baseServer.newConn = func( 1265 _ sendConn, 1266 runner connRunner, 1267 _ protocol.ConnectionID, 1268 _ *protocol.ConnectionID, 1269 _ protocol.ConnectionID, 1270 _ protocol.ConnectionID, 1271 _ protocol.ConnectionID, 1272 _ ConnectionIDGenerator, 1273 _ protocol.StatelessResetToken, 1274 _ *Config, 1275 _ *tls.Config, 1276 _ *handshake.TokenGenerator, 1277 _ bool, 1278 _ *logging.ConnectionTracer, 1279 _ ConnectionTracingID, 1280 _ utils.Logger, 1281 _ protocol.Version, 1282 ) quicConn { 1283 conn.EXPECT().handlePacket(p) 1284 conn.EXPECT().run() 1285 conn.EXPECT().earlyConnReady() 1286 conn.EXPECT().Context().Return(ctx) 1287 close(connCreated) 1288 return conn 1289 } 1290 1291 phm.EXPECT().Get(gomock.Any()) 1292 phm.EXPECT().GetStatelessResetToken(gomock.Any()) 1293 phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).Return(true) 1294 serv.baseServer.handlePacket(p) 1295 // make sure there are no Write calls on the packet conn 1296 time.Sleep(50 * time.Millisecond) 1297 Eventually(connCreated).Should(BeClosed()) 1298 cancel() 1299 time.Sleep(scaleDuration(200 * time.Millisecond)) 1300 1301 done := make(chan struct{}) 1302 go func() { 1303 defer GinkgoRecover() 1304 serv.Accept(context.Background()) 1305 close(done) 1306 }() 1307 Consistently(done).ShouldNot(BeClosed()) 1308 1309 // make the go routine return 1310 Expect(serv.Close()).To(Succeed()) 1311 Eventually(done).Should(BeClosed()) 1312 }) 1313 }) 1314 1315 Context("0-RTT", func() { 1316 var ( 1317 tr *Transport 1318 serv *baseServer 1319 phm *MockPacketHandlerManager 1320 tracer *mocklogging.MockTracer 1321 ) 1322 1323 BeforeEach(func() { 1324 var t *logging.Tracer 1325 t, tracer = mocklogging.NewMockTracer(mockCtrl) 1326 tr = &Transport{Conn: conn, Tracer: t} 1327 ln, err := tr.ListenEarly(tlsConf, nil) 1328 Expect(err).ToNot(HaveOccurred()) 1329 phm = NewMockPacketHandlerManager(mockCtrl) 1330 serv = ln.baseServer 1331 serv.connHandler = phm 1332 }) 1333 1334 AfterEach(func() { 1335 tracer.EXPECT().Close() 1336 Expect(tr.Close()).To(Succeed()) 1337 }) 1338 1339 It("passes packets to existing connections", func() { 1340 connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}) 1341 p := getPacket(&wire.Header{ 1342 Type: protocol.PacketType0RTT, 1343 DestConnectionID: connID, 1344 Version: serv.config.Versions[0], 1345 }, make([]byte, 100)) 1346 conn := NewMockPacketHandler(mockCtrl) 1347 phm.EXPECT().Get(connID).Return(conn, true) 1348 handled := make(chan struct{}) 1349 conn.EXPECT().handlePacket(p).Do(func(receivedPacket) { close(handled) }) 1350 serv.handlePacket(p) 1351 Eventually(handled).Should(BeClosed()) 1352 }) 1353 1354 It("queues 0-RTT packets, up to Max0RTTQueueSize", func() { 1355 connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}) 1356 1357 var zeroRTTPackets []receivedPacket 1358 1359 for i := 0; i < protocol.Max0RTTQueueLen; i++ { 1360 p := getPacket(&wire.Header{ 1361 Type: protocol.PacketType0RTT, 1362 DestConnectionID: connID, 1363 Version: serv.config.Versions[0], 1364 }, make([]byte, 100+i)) 1365 phm.EXPECT().Get(connID) 1366 serv.handlePacket(p) 1367 zeroRTTPackets = append(zeroRTTPackets, p) 1368 } 1369 1370 // send one more packet, this one should be dropped 1371 p := getPacket(&wire.Header{ 1372 Type: protocol.PacketType0RTT, 1373 DestConnectionID: connID, 1374 Version: serv.config.Versions[0], 1375 }, make([]byte, 200)) 1376 phm.EXPECT().Get(connID) 1377 tracer.EXPECT().DroppedPacket(p.remoteAddr, logging.PacketType0RTT, p.Size(), logging.PacketDropDOSPrevention) 1378 serv.handlePacket(p) 1379 1380 initial := getPacket(&wire.Header{ 1381 Type: protocol.PacketTypeInitial, 1382 DestConnectionID: connID, 1383 Version: serv.config.Versions[0], 1384 }, make([]byte, protocol.MinInitialPacketSize)) 1385 called := make(chan struct{}) 1386 serv.newConn = func( 1387 _ sendConn, 1388 _ connRunner, 1389 _ protocol.ConnectionID, 1390 _ *protocol.ConnectionID, 1391 _ protocol.ConnectionID, 1392 _ protocol.ConnectionID, 1393 _ protocol.ConnectionID, 1394 _ ConnectionIDGenerator, 1395 _ protocol.StatelessResetToken, 1396 _ *Config, 1397 _ *tls.Config, 1398 _ *handshake.TokenGenerator, 1399 _ bool, 1400 _ *logging.ConnectionTracer, 1401 _ ConnectionTracingID, 1402 _ utils.Logger, 1403 _ protocol.Version, 1404 ) quicConn { 1405 conn := NewMockQUICConn(mockCtrl) 1406 var calls []any 1407 calls = append(calls, conn.EXPECT().handlePacket(initial)) 1408 for _, p := range zeroRTTPackets { 1409 calls = append(calls, conn.EXPECT().handlePacket(p)) 1410 } 1411 gomock.InOrder(calls...) 1412 conn.EXPECT().run() 1413 conn.EXPECT().earlyConnReady() 1414 conn.EXPECT().Context().Return(context.Background()) 1415 close(called) 1416 // shutdown 1417 conn.EXPECT().closeWithTransportError(gomock.Any()) 1418 return conn 1419 } 1420 1421 phm.EXPECT().Get(connID) 1422 phm.EXPECT().GetStatelessResetToken(gomock.Any()) 1423 phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).Return(true) 1424 serv.handlePacket(initial) 1425 Eventually(called).Should(BeClosed()) 1426 }) 1427 1428 It("limits the number of queues", func() { 1429 for i := 0; i < protocol.Max0RTTQueues; i++ { 1430 b := make([]byte, 16) 1431 rand.Read(b) 1432 connID := protocol.ParseConnectionID(b) 1433 p := getPacket(&wire.Header{ 1434 Type: protocol.PacketType0RTT, 1435 DestConnectionID: connID, 1436 Version: serv.config.Versions[0], 1437 }, make([]byte, 100+i)) 1438 phm.EXPECT().Get(connID) 1439 serv.handlePacket(p) 1440 } 1441 1442 connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}) 1443 p := getPacket(&wire.Header{ 1444 Type: protocol.PacketType0RTT, 1445 DestConnectionID: connID, 1446 Version: serv.config.Versions[0], 1447 }, make([]byte, 200)) 1448 phm.EXPECT().Get(connID) 1449 dropped := make(chan struct{}) 1450 tracer.EXPECT().DroppedPacket(p.remoteAddr, logging.PacketType0RTT, p.Size(), logging.PacketDropDOSPrevention).Do(func(net.Addr, logging.PacketType, protocol.ByteCount, logging.PacketDropReason) { 1451 close(dropped) 1452 }) 1453 serv.handlePacket(p) 1454 Eventually(dropped).Should(BeClosed()) 1455 }) 1456 1457 It("drops queues after a while", func() { 1458 now := time.Now() 1459 1460 connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}) 1461 p := getPacket(&wire.Header{ 1462 Type: protocol.PacketType0RTT, 1463 DestConnectionID: connID, 1464 Version: serv.config.Versions[0], 1465 }, make([]byte, 200)) 1466 p.rcvTime = now 1467 1468 connID2 := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 9}) 1469 p2Time := now.Add(protocol.Max0RTTQueueingDuration / 2) 1470 p2 := getPacket(&wire.Header{ 1471 Type: protocol.PacketType0RTT, 1472 DestConnectionID: connID2, 1473 Version: serv.config.Versions[0], 1474 }, make([]byte, 300)) 1475 p2.rcvTime = p2Time // doesn't trigger the cleanup of the first packet 1476 1477 dropped1 := make(chan struct{}) 1478 dropped2 := make(chan struct{}) 1479 // need to register the call before handling the packet to avoid race condition 1480 gomock.InOrder( 1481 tracer.EXPECT().DroppedPacket(p.remoteAddr, logging.PacketType0RTT, p.Size(), logging.PacketDropDOSPrevention).Do(func(net.Addr, logging.PacketType, protocol.ByteCount, logging.PacketDropReason) { 1482 close(dropped1) 1483 }), 1484 tracer.EXPECT().DroppedPacket(p2.remoteAddr, logging.PacketType0RTT, p2.Size(), logging.PacketDropDOSPrevention).Do(func(net.Addr, logging.PacketType, protocol.ByteCount, logging.PacketDropReason) { 1485 close(dropped2) 1486 }), 1487 ) 1488 1489 phm.EXPECT().Get(connID) 1490 serv.handlePacket(p) 1491 1492 // There's no cleanup Go routine. 1493 // Cleanup is triggered when new packets are received. 1494 1495 phm.EXPECT().Get(connID2) 1496 serv.handlePacket(p2) 1497 // make sure no cleanup is executed 1498 Consistently(dropped1, 50*time.Millisecond).ShouldNot(BeClosed()) 1499 1500 // There's no cleanup Go routine. 1501 // Cleanup is triggered when new packets are received. 1502 connID3 := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 0}) 1503 p3 := getPacket(&wire.Header{ 1504 Type: protocol.PacketType0RTT, 1505 DestConnectionID: connID3, 1506 Version: serv.config.Versions[0], 1507 }, make([]byte, 200)) 1508 p3.rcvTime = now.Add(protocol.Max0RTTQueueingDuration + time.Nanosecond) // now triggers the cleanup 1509 phm.EXPECT().Get(connID3) 1510 serv.handlePacket(p3) 1511 Eventually(dropped1).Should(BeClosed()) 1512 Consistently(dropped2, 50*time.Millisecond).ShouldNot(BeClosed()) 1513 1514 // make sure the second packet is also cleaned up 1515 connID4 := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 1}) 1516 p4 := getPacket(&wire.Header{ 1517 Type: protocol.PacketType0RTT, 1518 DestConnectionID: connID4, 1519 Version: serv.config.Versions[0], 1520 }, make([]byte, 200)) 1521 p4.rcvTime = p2Time.Add(protocol.Max0RTTQueueingDuration + time.Nanosecond) // now triggers the cleanup 1522 phm.EXPECT().Get(connID4) 1523 serv.handlePacket(p4) 1524 Eventually(dropped2).Should(BeClosed()) 1525 }) 1526 }) 1527 })