github.com/sagernet/quic-go@v0.43.1-beta.1/ech/transport.go (about) 1 package quic 2 3 import ( 4 "context" 5 "crypto/rand" 6 "errors" 7 "net" 8 "sync" 9 "sync/atomic" 10 "time" 11 12 "github.com/sagernet/quic-go/internal/protocol" 13 "github.com/sagernet/quic-go/internal/utils" 14 "github.com/sagernet/quic-go/internal/wire" 15 "github.com/sagernet/quic-go/logging" 16 "github.com/sagernet/cloudflare-tls" 17 ) 18 19 var errListenerAlreadySet = errors.New("listener already set") 20 21 // The Transport is the central point to manage incoming and outgoing QUIC connections. 22 // QUIC demultiplexes connections based on their QUIC Connection IDs, not based on the 4-tuple. 23 // This means that a single UDP socket can be used for listening for incoming connections, as well as 24 // for dialing an arbitrary number of outgoing connections. 25 // A Transport handles a single net.PacketConn, and offers a range of configuration options 26 // compared to the simple helper functions like Listen and Dial that this package provides. 27 type Transport struct { 28 // A single net.PacketConn can only be handled by one Transport. 29 // Bad things will happen if passed to multiple Transports. 30 // 31 // A number of optimizations will be enabled if the connections implements the OOBCapablePacketConn interface, 32 // as a *net.UDPConn does. 33 // 1. It enables the Don't Fragment (DF) bit on the IP header. 34 // This is required to run DPLPMTUD (Path MTU Discovery, RFC 8899). 35 // 2. It enables reading of the ECN bits from the IP header. 36 // This allows the remote node to speed up its loss detection and recovery. 37 // 3. It uses batched syscalls (recvmmsg) to more efficiently receive packets from the socket. 38 // 4. It uses Generic Segmentation Offload (GSO) to efficiently send batches of packets (on Linux). 39 // 40 // After passing the connection to the Transport, it's invalid to call ReadFrom or WriteTo on the connection. 41 Conn net.PacketConn 42 43 // The length of the connection ID in bytes. 44 // It can be any value between 1 and 20. 45 // Due to the increased risk of collisions, it is not recommended to use connection IDs shorter than 4 bytes. 46 // If unset, a 4 byte connection ID will be used. 47 ConnectionIDLength int 48 49 // Use for generating new connection IDs. 50 // This allows the application to control of the connection IDs used, 51 // which allows routing / load balancing based on connection IDs. 52 // All Connection IDs returned by the ConnectionIDGenerator MUST 53 // have the same length. 54 ConnectionIDGenerator ConnectionIDGenerator 55 56 // The StatelessResetKey is used to generate stateless reset tokens. 57 // If no key is configured, sending of stateless resets is disabled. 58 // It is highly recommended to configure a stateless reset key, as stateless resets 59 // allow the peer to quickly recover from crashes and reboots of this node. 60 // See section 10.3 of RFC 9000 for details. 61 StatelessResetKey *StatelessResetKey 62 63 // The TokenGeneratorKey is used to encrypt session resumption tokens. 64 // If no key is configured, a random key will be generated. 65 // If multiple servers are authoritative for the same domain, they should use the same key, 66 // see section 8.1.3 of RFC 9000 for details. 67 TokenGeneratorKey *TokenGeneratorKey 68 69 // MaxTokenAge is the maximum age of the resumption token presented during the handshake. 70 // These tokens allow skipping address resumption when resuming a QUIC connection, 71 // and are especially useful when using 0-RTT. 72 // If not set, it defaults to 24 hours. 73 // See section 8.1.3 of RFC 9000 for details. 74 MaxTokenAge time.Duration 75 76 // DisableVersionNegotiationPackets disables the sending of Version Negotiation packets. 77 // This can be useful if version information is exchanged out-of-band. 78 // It has no effect for clients. 79 DisableVersionNegotiationPackets bool 80 81 // VerifySourceAddress decides if a connection attempt originating from unvalidated source 82 // addresses first needs to go through source address validation using QUIC's Retry mechanism, 83 // as described in RFC 9000 section 8.1.2. 84 // Note that the address passed to this callback is unvalidated, and might be spoofed in case 85 // of an attack. 86 // Validating the source address adds one additional network roundtrip to the handshake, 87 // and should therefore only be used if a suspiciously high number of incoming connection is recorded. 88 // For most use cases, wrapping the Allow function of a rate.Limiter will be a reasonable 89 // implementation of this callback (negating its return value). 90 VerifySourceAddress func(net.Addr) bool 91 92 // A Tracer traces events that don't belong to a single QUIC connection. 93 // Tracer.Close is called when the transport is closed. 94 Tracer *logging.Tracer 95 96 handlerMap packetHandlerManager 97 98 mutex sync.Mutex 99 initOnce sync.Once 100 initErr error 101 102 // Set in init. 103 // If no ConnectionIDGenerator is set, this is the ConnectionIDLength. 104 connIDLen int 105 // Set in init. 106 // If no ConnectionIDGenerator is set, this is set to a default. 107 connIDGenerator ConnectionIDGenerator 108 109 server *baseServer 110 111 conn rawConn 112 113 closeQueue chan closePacket 114 statelessResetQueue chan receivedPacket 115 116 listening chan struct{} // is closed when listen returns 117 closed bool 118 createdConn bool 119 isSingleUse bool // was created for a single server or client, i.e. by calling quic.Listen or quic.Dial 120 121 readingNonQUICPackets atomic.Bool 122 nonQUICPackets chan receivedPacket 123 124 logger utils.Logger 125 } 126 127 // Listen starts listening for incoming QUIC connections. 128 // There can only be a single listener on any net.PacketConn. 129 // Listen may only be called again after the current Listener was closed. 130 func (t *Transport) Listen(tlsConf *tls.Config, conf *Config) (*Listener, error) { 131 s, err := t.createServer(tlsConf, conf, false) 132 if err != nil { 133 return nil, err 134 } 135 return &Listener{baseServer: s}, nil 136 } 137 138 // ListenEarly starts listening for incoming QUIC connections. 139 // There can only be a single listener on any net.PacketConn. 140 // Listen may only be called again after the current Listener was closed. 141 func (t *Transport) ListenEarly(tlsConf *tls.Config, conf *Config) (*EarlyListener, error) { 142 s, err := t.createServer(tlsConf, conf, true) 143 if err != nil { 144 return nil, err 145 } 146 return &EarlyListener{baseServer: s}, nil 147 } 148 149 func (t *Transport) createServer(tlsConf *tls.Config, conf *Config, allow0RTT bool) (*baseServer, error) { 150 if tlsConf == nil { 151 return nil, errors.New("quic: tls.Config not set") 152 } 153 if err := validateConfig(conf); err != nil { 154 return nil, err 155 } 156 157 t.mutex.Lock() 158 defer t.mutex.Unlock() 159 160 if t.server != nil { 161 return nil, errListenerAlreadySet 162 } 163 conf = populateConfig(conf) 164 if err := t.init(false); err != nil { 165 return nil, err 166 } 167 s := newServer( 168 t.conn, 169 t.handlerMap, 170 t.connIDGenerator, 171 tlsConf, 172 conf, 173 t.Tracer, 174 t.closeServer, 175 *t.TokenGeneratorKey, 176 t.MaxTokenAge, 177 t.VerifySourceAddress, 178 t.DisableVersionNegotiationPackets, 179 allow0RTT, 180 ) 181 t.server = s 182 return s, nil 183 } 184 185 // Dial dials a new connection to a remote host (not using 0-RTT). 186 func (t *Transport) Dial(ctx context.Context, addr net.Addr, tlsConf *tls.Config, conf *Config) (Connection, error) { 187 return t.dial(ctx, addr, "", tlsConf, conf, false) 188 } 189 190 // DialEarly dials a new connection, attempting to use 0-RTT if possible. 191 func (t *Transport) DialEarly(ctx context.Context, addr net.Addr, tlsConf *tls.Config, conf *Config) (EarlyConnection, error) { 192 return t.dial(ctx, addr, "", tlsConf, conf, true) 193 } 194 195 func (t *Transport) dial(ctx context.Context, addr net.Addr, host string, tlsConf *tls.Config, conf *Config, use0RTT bool) (EarlyConnection, error) { 196 if err := validateConfig(conf); err != nil { 197 return nil, err 198 } 199 conf = populateConfig(conf) 200 if err := t.init(t.isSingleUse); err != nil { 201 return nil, err 202 } 203 var onClose func() 204 if t.isSingleUse { 205 onClose = func() { t.Close() } 206 } 207 tlsConf = tlsConf.Clone() 208 // setTLSConfigServerName(tlsConf, addr, host) 209 return dial(ctx, newSendConn(t.conn, addr, packetInfo{}, utils.DefaultLogger), t.connIDGenerator, t.handlerMap, tlsConf, conf, onClose, use0RTT) 210 } 211 212 func (t *Transport) init(allowZeroLengthConnIDs bool) error { 213 t.initOnce.Do(func() { 214 var conn rawConn 215 if c, ok := t.Conn.(rawConn); ok { 216 conn = c 217 } else { 218 var err error 219 conn, err = wrapConn(t.Conn) 220 if err != nil { 221 t.initErr = err 222 return 223 } 224 } 225 226 t.logger = utils.DefaultLogger // TODO: make this configurable 227 t.conn = conn 228 t.handlerMap = newPacketHandlerMap(t.StatelessResetKey, t.enqueueClosePacket, t.logger) 229 t.listening = make(chan struct{}) 230 231 t.closeQueue = make(chan closePacket, 4) 232 t.statelessResetQueue = make(chan receivedPacket, 4) 233 if t.TokenGeneratorKey == nil { 234 var key TokenGeneratorKey 235 if _, err := rand.Read(key[:]); err != nil { 236 t.initErr = err 237 return 238 } 239 t.TokenGeneratorKey = &key 240 } 241 242 if t.ConnectionIDGenerator != nil { 243 t.connIDGenerator = t.ConnectionIDGenerator 244 t.connIDLen = t.ConnectionIDGenerator.ConnectionIDLen() 245 } else { 246 connIDLen := t.ConnectionIDLength 247 if t.ConnectionIDLength == 0 && !allowZeroLengthConnIDs { 248 connIDLen = protocol.DefaultConnectionIDLength 249 } 250 t.connIDLen = connIDLen 251 t.connIDGenerator = &protocol.DefaultConnectionIDGenerator{ConnLen: t.connIDLen} 252 } 253 254 // getMultiplexer().AddConn(t.Conn) 255 go t.listen(conn) 256 go t.runSendQueue() 257 }) 258 return t.initErr 259 } 260 261 // WriteTo sends a packet on the underlying connection. 262 func (t *Transport) WriteTo(b []byte, addr net.Addr) (int, error) { 263 if err := t.init(false); err != nil { 264 return 0, err 265 } 266 return t.conn.WritePacket(b, addr, nil, 0, protocol.ECNUnsupported) 267 } 268 269 func (t *Transport) enqueueClosePacket(p closePacket) { 270 select { 271 case t.closeQueue <- p: 272 default: 273 // Oops, we're backlogged. 274 // Just drop the packet, sending CONNECTION_CLOSE copies is best effort anyway. 275 } 276 } 277 278 func (t *Transport) runSendQueue() { 279 for { 280 select { 281 case <-t.listening: 282 return 283 case p := <-t.closeQueue: 284 t.conn.WritePacket(p.payload, p.addr, p.info.OOB(), 0, protocol.ECNUnsupported) 285 case p := <-t.statelessResetQueue: 286 t.sendStatelessReset(p) 287 } 288 } 289 } 290 291 // Close closes the underlying connection. 292 // If any listener was started, it will be closed as well. 293 // It is invalid to start new listeners or connections after that. 294 func (t *Transport) Close() error { 295 t.close(errors.New("closing")) 296 if t.createdConn { 297 if err := t.Conn.Close(); err != nil { 298 return err 299 } 300 } else if t.conn != nil { 301 t.conn.SetReadDeadline(time.Now()) 302 defer func() { t.conn.SetReadDeadline(time.Time{}) }() 303 } 304 if t.listening != nil { 305 <-t.listening // wait until listening returns 306 } 307 return nil 308 } 309 310 func (t *Transport) closeServer() { 311 t.mutex.Lock() 312 t.server = nil 313 if t.isSingleUse { 314 t.closed = true 315 } 316 t.mutex.Unlock() 317 if t.createdConn { 318 t.Conn.Close() 319 } 320 if t.isSingleUse { 321 t.conn.SetReadDeadline(time.Now()) 322 defer func() { t.conn.SetReadDeadline(time.Time{}) }() 323 <-t.listening // wait until listening returns 324 } 325 } 326 327 func (t *Transport) close(e error) { 328 t.mutex.Lock() 329 defer t.mutex.Unlock() 330 if t.closed { 331 return 332 } 333 334 if t.handlerMap != nil { 335 t.handlerMap.Close(e) 336 } 337 if t.server != nil { 338 t.server.close(e, false) 339 } 340 if t.Tracer != nil && t.Tracer.Close != nil { 341 t.Tracer.Close() 342 } 343 t.closed = true 344 } 345 346 func (t *Transport) listen(conn rawConn) { 347 defer close(t.listening) 348 // defer getMultiplexer().RemoveConn(t.Conn) 349 350 for { 351 p, err := conn.ReadPacket() 352 //nolint:staticcheck // SA1019 ignore this! 353 // TODO: This code is used to ignore wsa errors on Windows. 354 // Since net.Error.Temporary is deprecated as of Go 1.18, we should find a better solution. 355 // See https://github.com/sagernet/quic-go/issues/1737 for details. 356 if nerr, ok := err.(net.Error); ok && nerr.Temporary() { 357 t.mutex.Lock() 358 closed := t.closed 359 t.mutex.Unlock() 360 if closed { 361 return 362 } 363 t.logger.Debugf("Temporary error reading from conn: %w", err) 364 continue 365 } 366 if err != nil { 367 // Windows returns an error when receiving a UDP datagram that doesn't fit into the provided buffer. 368 if isRecvMsgSizeErr(err) { 369 continue 370 } 371 t.close(err) 372 return 373 } 374 t.handlePacket(p) 375 } 376 } 377 378 func (t *Transport) handlePacket(p receivedPacket) { 379 if len(p.data) == 0 { 380 return 381 } 382 if !wire.IsPotentialQUICPacket(p.data[0]) && !wire.IsLongHeaderPacket(p.data[0]) { 383 t.handleNonQUICPacket(p) 384 return 385 } 386 connID, err := wire.ParseConnectionID(p.data, t.connIDLen) 387 if err != nil { 388 t.logger.Debugf("error parsing connection ID on packet from %s: %s", p.remoteAddr, err) 389 if t.Tracer != nil && t.Tracer.DroppedPacket != nil { 390 t.Tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropHeaderParseError) 391 } 392 p.buffer.MaybeRelease() 393 return 394 } 395 396 // If there's a connection associated with the connection ID, pass the packet there. 397 if handler, ok := t.handlerMap.Get(connID); ok { 398 handler.handlePacket(p) 399 return 400 } 401 // RFC 9000 section 10.3.1 requires that the stateless reset detection logic is run for both 402 // packets that cannot be associated with any connections, and for packets that can't be decrypted. 403 // We deviate from the RFC and ignore the latter: If a packet's connection ID is associated with an 404 // existing connection, it is dropped there if if it can't be decrypted. 405 // Stateless resets use random connection IDs, and at reasonable connection ID lengths collisions are 406 // exceedingly rare. In the unlikely event that a stateless reset is misrouted to an existing connection, 407 // it is to be expected that the next stateless reset will be correctly detected. 408 if isStatelessReset := t.maybeHandleStatelessReset(p.data); isStatelessReset { 409 return 410 } 411 if !wire.IsLongHeaderPacket(p.data[0]) { 412 t.maybeSendStatelessReset(p) 413 return 414 } 415 416 t.mutex.Lock() 417 defer t.mutex.Unlock() 418 if t.server == nil { // no server set 419 t.logger.Debugf("received a packet with an unexpected connection ID %s", connID) 420 return 421 } 422 t.server.handlePacket(p) 423 } 424 425 func (t *Transport) maybeSendStatelessReset(p receivedPacket) { 426 if t.StatelessResetKey == nil { 427 p.buffer.Release() 428 return 429 } 430 431 // Don't send a stateless reset in response to very small packets. 432 // This includes packets that could be stateless resets. 433 if len(p.data) <= protocol.MinStatelessResetSize { 434 p.buffer.Release() 435 return 436 } 437 438 select { 439 case t.statelessResetQueue <- p: 440 default: 441 // it's fine to not send a stateless reset when we're busy 442 p.buffer.Release() 443 } 444 } 445 446 func (t *Transport) sendStatelessReset(p receivedPacket) { 447 defer p.buffer.Release() 448 449 connID, err := wire.ParseConnectionID(p.data, t.connIDLen) 450 if err != nil { 451 t.logger.Errorf("error parsing connection ID on packet from %s: %s", p.remoteAddr, err) 452 return 453 } 454 token := t.handlerMap.GetStatelessResetToken(connID) 455 t.logger.Debugf("Sending stateless reset to %s (connection ID: %s). Token: %#x", p.remoteAddr, connID, token) 456 data := make([]byte, protocol.MinStatelessResetSize-16, protocol.MinStatelessResetSize) 457 rand.Read(data) 458 data[0] = (data[0] & 0x7f) | 0x40 459 data = append(data, token[:]...) 460 if _, err := t.conn.WritePacket(data, p.remoteAddr, p.info.OOB(), 0, protocol.ECNUnsupported); err != nil { 461 t.logger.Debugf("Error sending Stateless Reset to %s: %s", p.remoteAddr, err) 462 } 463 } 464 465 func (t *Transport) maybeHandleStatelessReset(data []byte) bool { 466 // stateless resets are always short header packets 467 if wire.IsLongHeaderPacket(data[0]) { 468 return false 469 } 470 if len(data) < 17 /* type byte + 16 bytes for the reset token */ { 471 return false 472 } 473 474 token := *(*protocol.StatelessResetToken)(data[len(data)-16:]) 475 if conn, ok := t.handlerMap.GetByResetToken(token); ok { 476 t.logger.Debugf("Received a stateless reset with token %#x. Closing connection.", token) 477 go conn.destroy(&StatelessResetError{Token: token}) 478 return true 479 } 480 return false 481 } 482 483 func (t *Transport) handleNonQUICPacket(p receivedPacket) { 484 // Strictly speaking, this is racy, 485 // but we only care about receiving packets at some point after ReadNonQUICPacket has been called. 486 if !t.readingNonQUICPackets.Load() { 487 return 488 } 489 select { 490 case t.nonQUICPackets <- p: 491 default: 492 if t.Tracer != nil && t.Tracer.DroppedPacket != nil { 493 t.Tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropDOSPrevention) 494 } 495 } 496 } 497 498 const maxQueuedNonQUICPackets = 32 499 500 // ReadNonQUICPacket reads non-QUIC packets received on the underlying connection. 501 // The detection logic is very simple: Any packet that has the first and second bit of the packet set to 0. 502 // Note that this is stricter than the detection logic defined in RFC 9443. 503 func (t *Transport) ReadNonQUICPacket(ctx context.Context, b []byte) (int, net.Addr, error) { 504 if err := t.init(false); err != nil { 505 return 0, nil, err 506 } 507 if !t.readingNonQUICPackets.Load() { 508 t.nonQUICPackets = make(chan receivedPacket, maxQueuedNonQUICPackets) 509 t.readingNonQUICPackets.Store(true) 510 } 511 select { 512 case <-ctx.Done(): 513 return 0, nil, ctx.Err() 514 case p := <-t.nonQUICPackets: 515 n := copy(b, p.data) 516 return n, p.remoteAddr, nil 517 case <-t.listening: 518 return 0, nil, errors.New("closed") 519 } 520 } 521 522 func setTLSConfigServerName(tlsConf *tls.Config, addr net.Addr, host string) { 523 // If no ServerName is set, infer the ServerName from the host we're connecting to. 524 if tlsConf.ServerName != "" { 525 return 526 } 527 if host == "" { 528 if udpAddr, ok := addr.(*net.UDPAddr); ok { 529 tlsConf.ServerName = udpAddr.IP.String() 530 return 531 } 532 } 533 h, _, err := net.SplitHostPort(host) 534 if err != nil { // This happens if the host doesn't contain a port number. 535 tlsConf.ServerName = host 536 return 537 } 538 tlsConf.ServerName = h 539 }