github.com/danielpfeifer02/quic-go-prio-packs@v0.41.0-28/transport.go (about) 1 package quic 2 3 import ( 4 "context" 5 "crypto/rand" 6 "crypto/tls" 7 "errors" 8 "math" 9 "net" 10 "sync" 11 "sync/atomic" 12 "time" 13 14 "github.com/danielpfeifer02/quic-go-prio-packs/internal/protocol" 15 "github.com/danielpfeifer02/quic-go-prio-packs/internal/utils" 16 "github.com/danielpfeifer02/quic-go-prio-packs/internal/wire" 17 "github.com/danielpfeifer02/quic-go-prio-packs/logging" 18 "github.com/danielpfeifer02/quic-go-prio-packs/priority_setting" 19 ) 20 21 var errListenerAlreadySet = errors.New("listener already set") 22 23 const ( 24 // defaultMaxNumUnvalidatedHandshakes is the default value for Transport.MaxUnvalidatedHandshakes. 25 defaultMaxNumUnvalidatedHandshakes = 128 26 // defaultMaxNumHandshakes is the default value for Transport.MaxHandshakes. 27 // It's not clear how to choose a reasonable value that works for all use cases. 28 // In production, implementations should: 29 // 1. Choose a lower value. 30 // 2. Implement some kind of IP-address based filtering using the Config.GetConfigForClient 31 // callback in order to prevent flooding attacks from a single / small number of IP addresses. 32 defaultMaxNumHandshakes = math.MaxInt32 33 ) 34 35 // The Transport is the central point to manage incoming and outgoing QUIC connections. 36 // QUIC demultiplexes connections based on their QUIC Connection IDs, not based on the 4-tuple. 37 // This means that a single UDP socket can be used for listening for incoming connections, as well as 38 // for dialing an arbitrary number of outgoing connections. 39 // A Transport handles a single net.PacketConn, and offers a range of configuration options 40 // compared to the simple helper functions like Listen and Dial that this package provides. 41 type Transport struct { 42 // A single net.PacketConn can only be handled by one Transport. 43 // Bad things will happen if passed to multiple Transports. 44 // 45 // A number of optimizations will be enabled if the connections implements the OOBCapablePacketConn interface, 46 // as a *net.UDPConn does. 47 // 1. It enables the Don't Fragment (DF) bit on the IP header. 48 // This is required to run DPLPMTUD (Path MTU Discovery, RFC 8899). 49 // 2. It enables reading of the ECN bits from the IP header. 50 // This allows the remote node to speed up its loss detection and recovery. 51 // 3. It uses batched syscalls (recvmmsg) to more efficiently receive packets from the socket. 52 // 4. It uses Generic Segmentation Offload (GSO) to efficiently send batches of packets (on Linux). 53 // 54 // After passing the connection to the Transport, it's invalid to call ReadFrom or WriteTo on the connection. 55 Conn net.PacketConn 56 57 // The length of the connection ID in bytes. 58 // It can be any value between 1 and 20. 59 // Due to the increased risk of collisions, it is not recommended to use connection IDs shorter than 4 bytes. 60 // If unset, a 4 byte connection ID will be used. 61 ConnectionIDLength int 62 63 // Use for generating new connection IDs. 64 // This allows the application to control of the connection IDs used, 65 // which allows routing / load balancing based on connection IDs. 66 // All Connection IDs returned by the ConnectionIDGenerator MUST 67 // have the same length. 68 ConnectionIDGenerator ConnectionIDGenerator 69 70 // The StatelessResetKey is used to generate stateless reset tokens. 71 // If no key is configured, sending of stateless resets is disabled. 72 // It is highly recommended to configure a stateless reset key, as stateless resets 73 // allow the peer to quickly recover from crashes and reboots of this node. 74 // See section 10.3 of RFC 9000 for details. 75 StatelessResetKey *StatelessResetKey 76 77 // The TokenGeneratorKey is used to encrypt session resumption tokens. 78 // If no key is configured, a random key will be generated. 79 // If multiple servers are authoritative for the same domain, they should use the same key, 80 // see section 8.1.3 of RFC 9000 for details. 81 TokenGeneratorKey *TokenGeneratorKey 82 83 // MaxTokenAge is the maximum age of the resumption token presented during the handshake. 84 // These tokens allow skipping address resumption when resuming a QUIC connection, 85 // and are especially useful when using 0-RTT. 86 // If not set, it defaults to 24 hours. 87 // See section 8.1.3 of RFC 9000 for details. 88 MaxTokenAge time.Duration 89 90 // DisableVersionNegotiationPackets disables the sending of Version Negotiation packets. 91 // This can be useful if version information is exchanged out-of-band. 92 // It has no effect for clients. 93 DisableVersionNegotiationPackets bool 94 95 // MaxUnvalidatedHandshakes is the maximum number of concurrent incoming QUIC handshakes 96 // originating from unvalidated source addresses. 97 // If the number of handshakes from unvalidated addresses reaches this number, new incoming 98 // connection attempts will need to proof reachability at the respective source address using the 99 // Retry mechanism, as described in RFC 9000 section 8.1.2. 100 // Validating the source address adds one additional network roundtrip to the handshake. 101 // If unset, a default value of 128 will be used. 102 // When set to a negative value, every connection attempt will need to validate the source address. 103 // It does not make sense to set this value higher than MaxHandshakes. 104 MaxUnvalidatedHandshakes int 105 // MaxHandshakes is the maximum number of concurrent incoming handshakes, both from validated 106 // and unvalidated source addresses. 107 // If unset, the number of concurrent handshakes will not be limited. 108 // Applications should choose a reasonable value based on their thread model, and consider 109 // implementing IP-based rate limiting using Config.GetConfigForClient. 110 // If the number of handshakes reaches this number, new connection attempts will be rejected by 111 // terminating the connection attempt using a CONNECTION_REFUSED error. 112 MaxHandshakes int 113 114 // A Tracer traces events that don't belong to a single QUIC connection. 115 // Tracer.Close is called when the transport is closed. 116 Tracer *logging.Tracer 117 118 handlerMap packetHandlerManager 119 120 mutex sync.Mutex 121 initOnce sync.Once 122 initErr error 123 124 // Set in init. 125 // If no ConnectionIDGenerator is set, this is the ConnectionIDLength. 126 connIDLen int 127 // Set in init. 128 // If no ConnectionIDGenerator is set, this is set to a default. 129 connIDGenerator ConnectionIDGenerator 130 131 server *baseServer 132 133 conn rawConn 134 135 closeQueue chan closePacket 136 statelessResetQueue chan receivedPacket 137 138 listening chan struct{} // is closed when listen returns 139 closed bool 140 createdConn bool 141 isSingleUse bool // was created for a single server or client, i.e. by calling quic.Listen or quic.Dial 142 143 readingNonQUICPackets atomic.Bool 144 nonQUICPackets chan receivedPacket 145 146 logger utils.Logger 147 } 148 149 // Listen starts listening for incoming QUIC connections. 150 // There can only be a single listener on any net.PacketConn. 151 // Listen may only be called again after the current Listener was closed. 152 func (t *Transport) Listen(tlsConf *tls.Config, conf *Config) (*Listener, error) { 153 s, err := t.createServer(tlsConf, conf, false) 154 if err != nil { 155 return nil, err 156 } 157 return &Listener{baseServer: s}, nil 158 } 159 160 // ListenEarly starts listening for incoming QUIC connections. 161 // There can only be a single listener on any net.PacketConn. 162 // Listen may only be called again after the current Listener was closed. 163 func (t *Transport) ListenEarly(tlsConf *tls.Config, conf *Config) (*EarlyListener, error) { 164 s, err := t.createServer(tlsConf, conf, true) 165 if err != nil { 166 return nil, err 167 } 168 return &EarlyListener{baseServer: s}, nil 169 } 170 171 func (t *Transport) createServer(tlsConf *tls.Config, conf *Config, allow0RTT bool) (*baseServer, error) { 172 if tlsConf == nil { 173 return nil, errors.New("quic: tls.Config not set") 174 } 175 if err := validateConfig(conf); err != nil { 176 return nil, err 177 } 178 179 t.mutex.Lock() 180 defer t.mutex.Unlock() 181 182 if t.server != nil { 183 return nil, errListenerAlreadySet 184 } 185 conf = populateConfig(conf) 186 if err := t.init(false); err != nil { 187 return nil, err 188 } 189 maxUnvalidatedHandshakes := t.MaxUnvalidatedHandshakes 190 if maxUnvalidatedHandshakes == 0 { 191 maxUnvalidatedHandshakes = defaultMaxNumUnvalidatedHandshakes 192 } 193 maxHandshakes := t.MaxHandshakes 194 if maxHandshakes == 0 { 195 maxHandshakes = defaultMaxNumHandshakes 196 } 197 s := newServer( 198 t.conn, 199 t.handlerMap, 200 t.connIDGenerator, 201 tlsConf, 202 conf, 203 t.Tracer, 204 t.closeServer, 205 *t.TokenGeneratorKey, 206 t.MaxTokenAge, 207 maxUnvalidatedHandshakes, 208 maxHandshakes, 209 t.DisableVersionNegotiationPackets, 210 allow0RTT, 211 ) 212 t.server = s 213 return s, nil 214 } 215 216 // Dial dials a new connection to a remote host (not using 0-RTT). 217 func (t *Transport) Dial(ctx context.Context, addr net.Addr, tlsConf *tls.Config, conf *Config) (Connection, error) { 218 return t.dial(ctx, addr, "", tlsConf, conf, false) 219 } 220 221 // DialEarly dials a new connection, attempting to use 0-RTT if possible. 222 func (t *Transport) DialEarly(ctx context.Context, addr net.Addr, tlsConf *tls.Config, conf *Config) (EarlyConnection, error) { 223 return t.dial(ctx, addr, "", tlsConf, conf, true) 224 } 225 226 func (t *Transport) dial(ctx context.Context, addr net.Addr, host string, tlsConf *tls.Config, conf *Config, use0RTT bool) (EarlyConnection, error) { 227 if err := validateConfig(conf); err != nil { 228 return nil, err 229 } 230 conf = populateConfig(conf) 231 if err := t.init(t.isSingleUse); err != nil { 232 return nil, err 233 } 234 var onClose func() 235 if t.isSingleUse { 236 onClose = func() { t.Close() } 237 } 238 tlsConf = tlsConf.Clone() 239 setTLSConfigServerName(tlsConf, addr, host) 240 return dial(ctx, newSendConn(t.conn, addr, packetInfo{}, utils.DefaultLogger), t.connIDGenerator, t.handlerMap, tlsConf, conf, onClose, use0RTT) 241 } 242 243 func (t *Transport) init(allowZeroLengthConnIDs bool) error { 244 t.initOnce.Do(func() { 245 var conn rawConn 246 if c, ok := t.Conn.(rawConn); ok { 247 conn = c 248 } else { 249 var err error 250 conn, err = wrapConn(t.Conn) 251 if err != nil { 252 t.initErr = err 253 return 254 } 255 } 256 257 t.logger = utils.DefaultLogger // TODO: make this configurable 258 t.conn = conn 259 t.handlerMap = newPacketHandlerMap(t.StatelessResetKey, t.enqueueClosePacket, t.logger) 260 t.listening = make(chan struct{}) 261 262 t.closeQueue = make(chan closePacket, 4) 263 t.statelessResetQueue = make(chan receivedPacket, 4) 264 if t.TokenGeneratorKey == nil { 265 var key TokenGeneratorKey 266 if _, err := rand.Read(key[:]); err != nil { 267 t.initErr = err 268 return 269 } 270 t.TokenGeneratorKey = &key 271 } 272 273 if t.ConnectionIDGenerator != nil { 274 t.connIDGenerator = t.ConnectionIDGenerator 275 t.connIDLen = t.ConnectionIDGenerator.ConnectionIDLen() 276 } else { 277 connIDLen := t.ConnectionIDLength 278 if t.ConnectionIDLength == 0 && !allowZeroLengthConnIDs { 279 connIDLen = protocol.DefaultConnectionIDLength 280 } 281 t.connIDLen = connIDLen 282 // PRIO_PACKS_TAG 283 // PriorityCounter has to start at once since if we start at 0 (i.e. low prio) 284 // and we have a high prio packet it is gonna get drupped, which is worse than 285 // not dropping a low prio packet. 286 // Once the connection has more than 1 connection IDs this problem is gone. 287 // TODOME: temporary fix, remove once we have a better solution 288 t.connIDLen = 16 289 t.connIDGenerator = &protocol.PriorityConnectionIDGenerator{ConnLen: t.connIDLen, 290 PriorityCounter: priority_setting.LowestPriority, NumberOfPriorities: priority_setting.NumberOfPriorities} 291 } 292 293 getMultiplexer().AddConn(t.Conn) 294 go t.listen(conn) 295 go t.runSendQueue() 296 }) 297 return t.initErr 298 } 299 300 // WriteTo sends a packet on the underlying connection. 301 func (t *Transport) WriteTo(b []byte, addr net.Addr) (int, error) { 302 if err := t.init(false); err != nil { 303 return 0, err 304 } 305 return t.conn.WritePacket(b, addr, nil, 0, protocol.ECNUnsupported) 306 } 307 308 func (t *Transport) enqueueClosePacket(p closePacket) { 309 select { 310 case t.closeQueue <- p: 311 default: 312 // Oops, we're backlogged. 313 // Just drop the packet, sending CONNECTION_CLOSE copies is best effort anyway. 314 } 315 } 316 317 func (t *Transport) runSendQueue() { 318 for { 319 select { 320 case <-t.listening: 321 return 322 case p := <-t.closeQueue: 323 t.conn.WritePacket(p.payload, p.addr, p.info.OOB(), 0, protocol.ECNUnsupported) 324 case p := <-t.statelessResetQueue: 325 t.sendStatelessReset(p) 326 } 327 } 328 } 329 330 // Close closes the underlying connection. 331 // If any listener was started, it will be closed as well. 332 // It is invalid to start new listeners or connections after that. 333 func (t *Transport) Close() error { 334 t.close(errors.New("closing")) 335 if t.createdConn { 336 if err := t.Conn.Close(); err != nil { 337 return err 338 } 339 } else if t.conn != nil { 340 t.conn.SetReadDeadline(time.Now()) 341 defer func() { t.conn.SetReadDeadline(time.Time{}) }() 342 } 343 if t.listening != nil { 344 <-t.listening // wait until listening returns 345 } 346 return nil 347 } 348 349 func (t *Transport) closeServer() { 350 t.mutex.Lock() 351 t.server = nil 352 if t.isSingleUse { 353 t.closed = true 354 } 355 t.mutex.Unlock() 356 if t.createdConn { 357 t.Conn.Close() 358 } 359 if t.isSingleUse { 360 t.conn.SetReadDeadline(time.Now()) 361 defer func() { t.conn.SetReadDeadline(time.Time{}) }() 362 <-t.listening // wait until listening returns 363 } 364 } 365 366 func (t *Transport) close(e error) { 367 t.mutex.Lock() 368 defer t.mutex.Unlock() 369 if t.closed { 370 return 371 } 372 373 if t.handlerMap != nil { 374 t.handlerMap.Close(e) 375 } 376 if t.server != nil { 377 t.server.close(e, false) 378 } 379 if t.Tracer != nil && t.Tracer.Close != nil { 380 t.Tracer.Close() 381 } 382 t.closed = true 383 } 384 385 // only print warnings about the UDP receive buffer size once 386 var setBufferWarningOnce sync.Once 387 388 func (t *Transport) listen(conn rawConn) { 389 defer close(t.listening) 390 defer getMultiplexer().RemoveConn(t.Conn) 391 392 for { 393 p, err := conn.ReadPacket() 394 //nolint:staticcheck // SA1019 ignore this! 395 // TODO: This code is used to ignore wsa errors on Windows. 396 // Since net.Error.Temporary is deprecated as of Go 1.18, we should find a better solution. 397 // See https://github.com/danielpfeifer02/quic-go-prio-packs/issues/1737 for details. 398 if nerr, ok := err.(net.Error); ok && nerr.Temporary() { 399 t.mutex.Lock() 400 closed := t.closed 401 t.mutex.Unlock() 402 if closed { 403 return 404 } 405 t.logger.Debugf("Temporary error reading from conn: %w", err) 406 continue 407 } 408 if err != nil { 409 // Windows returns an error when receiving a UDP datagram that doesn't fit into the provided buffer. 410 if isRecvMsgSizeErr(err) { 411 continue 412 } 413 t.close(err) 414 return 415 } 416 t.handlePacket(p) 417 } 418 } 419 420 func (t *Transport) handlePacket(p receivedPacket) { 421 if len(p.data) == 0 { 422 return 423 } 424 if !wire.IsPotentialQUICPacket(p.data[0]) && !wire.IsLongHeaderPacket(p.data[0]) { 425 t.handleNonQUICPacket(p) 426 return 427 } 428 connID, err := wire.ParseConnectionID(p.data, t.connIDLen) 429 if err != nil { 430 t.logger.Debugf("error parsing connection ID on packet from %s: %s", p.remoteAddr, err) 431 if t.Tracer != nil && t.Tracer.DroppedPacket != nil { 432 t.Tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropHeaderParseError) 433 } 434 p.buffer.MaybeRelease() 435 return 436 } 437 438 // If there's a connection associated with the connection ID, pass the packet there. 439 if handler, ok := t.handlerMap.Get(connID); ok { 440 handler.handlePacket(p) 441 return 442 } 443 // RFC 9000 section 10.3.1 requires that the stateless reset detection logic is run for both 444 // packets that cannot be associated with any connections, and for packets that can't be decrypted. 445 // We deviate from the RFC and ignore the latter: If a packet's connection ID is associated with an 446 // existing connection, it is dropped there if if it can't be decrypted. 447 // Stateless resets use random connection IDs, and at reasonable connection ID lengths collisions are 448 // exceedingly rare. In the unlikely event that a stateless reset is misrouted to an existing connection, 449 // it is to be expected that the next stateless reset will be correctly detected. 450 if isStatelessReset := t.maybeHandleStatelessReset(p.data); isStatelessReset { 451 return 452 } 453 if !wire.IsLongHeaderPacket(p.data[0]) { 454 t.maybeSendStatelessReset(p) 455 return 456 } 457 458 t.mutex.Lock() 459 defer t.mutex.Unlock() 460 if t.server == nil { // no server set 461 t.logger.Debugf("received a packet with an unexpected connection ID %s", connID) 462 return 463 } 464 t.server.handlePacket(p) 465 } 466 467 func (t *Transport) maybeSendStatelessReset(p receivedPacket) { 468 if t.StatelessResetKey == nil { 469 p.buffer.Release() 470 return 471 } 472 473 // Don't send a stateless reset in response to very small packets. 474 // This includes packets that could be stateless resets. 475 if len(p.data) <= protocol.MinStatelessResetSize { 476 p.buffer.Release() 477 return 478 } 479 480 select { 481 case t.statelessResetQueue <- p: 482 default: 483 // it's fine to not send a stateless reset when we're busy 484 p.buffer.Release() 485 } 486 } 487 488 func (t *Transport) sendStatelessReset(p receivedPacket) { 489 defer p.buffer.Release() 490 491 connID, err := wire.ParseConnectionID(p.data, t.connIDLen) 492 if err != nil { 493 t.logger.Errorf("error parsing connection ID on packet from %s: %s", p.remoteAddr, err) 494 return 495 } 496 token := t.handlerMap.GetStatelessResetToken(connID) 497 t.logger.Debugf("Sending stateless reset to %s (connection ID: %s). Token: %#x", p.remoteAddr, connID, token) 498 data := make([]byte, protocol.MinStatelessResetSize-16, protocol.MinStatelessResetSize) 499 rand.Read(data) 500 data[0] = (data[0] & 0x7f) | 0x40 501 data = append(data, token[:]...) 502 if _, err := t.conn.WritePacket(data, p.remoteAddr, p.info.OOB(), 0, protocol.ECNUnsupported); err != nil { 503 t.logger.Debugf("Error sending Stateless Reset to %s: %s", p.remoteAddr, err) 504 } 505 } 506 507 func (t *Transport) maybeHandleStatelessReset(data []byte) bool { 508 // stateless resets are always short header packets 509 if wire.IsLongHeaderPacket(data[0]) { 510 return false 511 } 512 if len(data) < 17 /* type byte + 16 bytes for the reset token */ { 513 return false 514 } 515 516 token := *(*protocol.StatelessResetToken)(data[len(data)-16:]) 517 if conn, ok := t.handlerMap.GetByResetToken(token); ok { 518 t.logger.Debugf("Received a stateless reset with token %#x. Closing connection.", token) 519 go conn.destroy(&StatelessResetError{Token: token}) 520 return true 521 } 522 return false 523 } 524 525 func (t *Transport) handleNonQUICPacket(p receivedPacket) { 526 // Strictly speaking, this is racy, 527 // but we only care about receiving packets at some point after ReadNonQUICPacket has been called. 528 if !t.readingNonQUICPackets.Load() { 529 return 530 } 531 select { 532 case t.nonQUICPackets <- p: 533 default: 534 if t.Tracer != nil && t.Tracer.DroppedPacket != nil { 535 t.Tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropDOSPrevention) 536 } 537 } 538 } 539 540 const maxQueuedNonQUICPackets = 32 541 542 // ReadNonQUICPacket reads non-QUIC packets received on the underlying connection. 543 // The detection logic is very simple: Any packet that has the first and second bit of the packet set to 0. 544 // Note that this is stricter than the detection logic defined in RFC 9443. 545 func (t *Transport) ReadNonQUICPacket(ctx context.Context, b []byte) (int, net.Addr, error) { 546 if err := t.init(false); err != nil { 547 return 0, nil, err 548 } 549 if !t.readingNonQUICPackets.Load() { 550 t.nonQUICPackets = make(chan receivedPacket, maxQueuedNonQUICPackets) 551 t.readingNonQUICPackets.Store(true) 552 } 553 select { 554 case <-ctx.Done(): 555 return 0, nil, ctx.Err() 556 case p := <-t.nonQUICPackets: 557 n := copy(b, p.data) 558 return n, p.remoteAddr, nil 559 case <-t.listening: 560 return 0, nil, errors.New("closed") 561 } 562 } 563 564 func setTLSConfigServerName(tlsConf *tls.Config, addr net.Addr, host string) { 565 // If no ServerName is set, infer the ServerName from the host we're connecting to. 566 if tlsConf.ServerName != "" { 567 return 568 } 569 if host == "" { 570 if udpAddr, ok := addr.(*net.UDPAddr); ok { 571 tlsConf.ServerName = udpAddr.IP.String() 572 return 573 } 574 } 575 h, _, err := net.SplitHostPort(host) 576 if err != nil { // This happens if the host doesn't contain a port number. 577 tlsConf.ServerName = host 578 return 579 } 580 tlsConf.ServerName = h 581 }