github.com/ooni/psiphon/tunnel-core@v0.0.0-20230105123940-fe12a24c96ee/oovendor/quic-go/packet_handler_map.go (about) 1 package quic 2 3 import ( 4 "crypto/hmac" 5 "crypto/rand" 6 "crypto/sha256" 7 "errors" 8 "fmt" 9 "hash" 10 "net" 11 "os" 12 "strconv" 13 "strings" 14 "sync" 15 "time" 16 17 "github.com/ooni/psiphon/tunnel-core/oovendor/quic-go/internal/protocol" 18 "github.com/ooni/psiphon/tunnel-core/oovendor/quic-go/internal/utils" 19 "github.com/ooni/psiphon/tunnel-core/oovendor/quic-go/internal/wire" 20 "github.com/ooni/psiphon/tunnel-core/oovendor/quic-go/logging" 21 ) 22 23 type zeroRTTQueue struct { 24 queue []*receivedPacket 25 retireTimer *time.Timer 26 } 27 28 var _ packetHandler = &zeroRTTQueue{} 29 30 func (h *zeroRTTQueue) handlePacket(p *receivedPacket) { 31 if len(h.queue) < protocol.Max0RTTQueueLen { 32 h.queue = append(h.queue, p) 33 } 34 } 35 func (h *zeroRTTQueue) shutdown() {} 36 func (h *zeroRTTQueue) destroy(error) {} 37 func (h *zeroRTTQueue) getPerspective() protocol.Perspective { return protocol.PerspectiveClient } 38 func (h *zeroRTTQueue) EnqueueAll(sess packetHandler) { 39 for _, p := range h.queue { 40 sess.handlePacket(p) 41 } 42 } 43 44 func (h *zeroRTTQueue) Clear() { 45 for _, p := range h.queue { 46 p.buffer.Release() 47 } 48 } 49 50 type packetHandlerMapEntry struct { 51 packetHandler packetHandler 52 is0RTTQueue bool 53 } 54 55 // The packetHandlerMap stores packetHandlers, identified by connection ID. 56 // It is used: 57 // * by the server to store sessions 58 // * when multiplexing outgoing connections to store clients 59 type packetHandlerMap struct { 60 mutex sync.Mutex 61 62 conn connection 63 connIDLen int 64 65 handlers map[string] /* string(ConnectionID)*/ packetHandlerMapEntry 66 resetTokens map[protocol.StatelessResetToken] /* stateless reset token */ packetHandler 67 server unknownPacketHandler 68 numZeroRTTEntries int 69 70 listening chan struct{} // is closed when listen returns 71 closed bool 72 73 deleteRetiredSessionsAfter time.Duration 74 zeroRTTQueueDuration time.Duration 75 76 statelessResetEnabled bool 77 statelessResetMutex sync.Mutex 78 statelessResetHasher hash.Hash 79 80 tracer logging.Tracer 81 logger utils.Logger 82 } 83 84 var _ packetHandlerManager = &packetHandlerMap{} 85 86 func setReceiveBuffer(c net.PacketConn, logger utils.Logger) error { 87 conn, ok := c.(interface{ SetReadBuffer(int) error }) 88 if !ok { 89 return errors.New("connection doesn't allow setting of receive buffer size. Not a *net.UDPConn?") 90 } 91 size, err := inspectReadBuffer(c) 92 if err != nil { 93 return fmt.Errorf("failed to determine receive buffer size: %w", err) 94 } 95 if size >= protocol.DesiredReceiveBufferSize { 96 logger.Debugf("Conn has receive buffer of %d kiB (wanted: at least %d kiB)", size/1024, protocol.DesiredReceiveBufferSize/1024) 97 } 98 if err := conn.SetReadBuffer(protocol.DesiredReceiveBufferSize); err != nil { 99 return fmt.Errorf("failed to increase receive buffer size: %w", err) 100 } 101 newSize, err := inspectReadBuffer(c) 102 if err != nil { 103 return fmt.Errorf("failed to determine receive buffer size: %w", err) 104 } 105 if newSize == size { 106 return fmt.Errorf("failed to increase receive buffer size (wanted: %d kiB, got %d kiB)", protocol.DesiredReceiveBufferSize/1024, newSize/1024) 107 } 108 if newSize < protocol.DesiredReceiveBufferSize { 109 return fmt.Errorf("failed to sufficiently increase receive buffer size (was: %d kiB, wanted: %d kiB, got: %d kiB)", size/1024, protocol.DesiredReceiveBufferSize/1024, newSize/1024) 110 } 111 logger.Debugf("Increased receive buffer size to %d kiB", newSize/1024) 112 return nil 113 } 114 115 // only print warnings about the UDP receive buffer size once 116 var receiveBufferWarningOnce sync.Once 117 118 func newPacketHandlerMap( 119 c net.PacketConn, 120 connIDLen int, 121 statelessResetKey []byte, 122 tracer logging.Tracer, 123 logger utils.Logger, 124 ) (packetHandlerManager, error) { 125 if err := setReceiveBuffer(c, logger); err != nil { 126 if !strings.Contains(err.Error(), "use of closed network connection") { 127 receiveBufferWarningOnce.Do(func() { 128 if disable, _ := strconv.ParseBool(os.Getenv("QUIC_GO_DISABLE_RECEIVE_BUFFER_WARNING")); disable { 129 return 130 } 131 // [Psiphon] 132 // Do not emit alert to stderr (was log.Printf). 133 logger.Errorf("%s. See https://github.com/lucas-clemente/quic-go/wiki/UDP-Receive-Buffer-Size for details.", err) 134 }) 135 } 136 } 137 conn, err := wrapConn(c) 138 if err != nil { 139 return nil, err 140 } 141 m := &packetHandlerMap{ 142 conn: conn, 143 connIDLen: connIDLen, 144 listening: make(chan struct{}), 145 handlers: make(map[string]packetHandlerMapEntry), 146 resetTokens: make(map[protocol.StatelessResetToken]packetHandler), 147 deleteRetiredSessionsAfter: protocol.RetiredConnectionIDDeleteTimeout, 148 zeroRTTQueueDuration: protocol.Max0RTTQueueingDuration, 149 statelessResetEnabled: len(statelessResetKey) > 0, 150 statelessResetHasher: hmac.New(sha256.New, statelessResetKey), 151 tracer: tracer, 152 logger: logger, 153 } 154 go m.listen() 155 156 if logger.Debug() { 157 go m.logUsage() 158 } 159 return m, nil 160 } 161 162 func (h *packetHandlerMap) logUsage() { 163 ticker := time.NewTicker(2 * time.Second) 164 var printedZero bool 165 for { 166 select { 167 case <-h.listening: 168 return 169 case <-ticker.C: 170 } 171 172 h.mutex.Lock() 173 numHandlers := len(h.handlers) 174 numTokens := len(h.resetTokens) 175 h.mutex.Unlock() 176 // If the number tracked handlers and tokens is zero, only print it a single time. 177 hasZero := numHandlers == 0 && numTokens == 0 178 if !hasZero || (hasZero && !printedZero) { 179 h.logger.Debugf("Tracking %d connection IDs and %d reset tokens.\n", numHandlers, numTokens) 180 printedZero = false 181 if hasZero { 182 printedZero = true 183 } 184 } 185 } 186 } 187 188 func (h *packetHandlerMap) Add(id protocol.ConnectionID, handler packetHandler) bool /* was added */ { 189 h.mutex.Lock() 190 defer h.mutex.Unlock() 191 192 if _, ok := h.handlers[string(id)]; ok { 193 h.logger.Debugf("Not adding connection ID %s, as it already exists.", id) 194 return false 195 } 196 h.handlers[string(id)] = packetHandlerMapEntry{packetHandler: handler} 197 h.logger.Debugf("Adding connection ID %s.", id) 198 return true 199 } 200 201 func (h *packetHandlerMap) AddWithConnID(clientDestConnID, newConnID protocol.ConnectionID, fn func() packetHandler) bool { 202 h.mutex.Lock() 203 defer h.mutex.Unlock() 204 205 var q *zeroRTTQueue 206 if entry, ok := h.handlers[string(clientDestConnID)]; ok { 207 if !entry.is0RTTQueue { 208 h.logger.Debugf("Not adding connection ID %s for a new session, as it already exists.", clientDestConnID) 209 return false 210 } 211 q = entry.packetHandler.(*zeroRTTQueue) 212 q.retireTimer.Stop() 213 h.numZeroRTTEntries-- 214 if h.numZeroRTTEntries < 0 { 215 panic("number of 0-RTT queues < 0") 216 } 217 } 218 sess := fn() 219 if q != nil { 220 q.EnqueueAll(sess) 221 } 222 h.handlers[string(clientDestConnID)] = packetHandlerMapEntry{packetHandler: sess} 223 h.handlers[string(newConnID)] = packetHandlerMapEntry{packetHandler: sess} 224 h.logger.Debugf("Adding connection IDs %s and %s for a new session.", clientDestConnID, newConnID) 225 return true 226 } 227 228 func (h *packetHandlerMap) Remove(id protocol.ConnectionID) { 229 h.mutex.Lock() 230 delete(h.handlers, string(id)) 231 h.mutex.Unlock() 232 h.logger.Debugf("Removing connection ID %s.", id) 233 } 234 235 func (h *packetHandlerMap) Retire(id protocol.ConnectionID) { 236 h.logger.Debugf("Retiring connection ID %s in %s.", id, h.deleteRetiredSessionsAfter) 237 time.AfterFunc(h.deleteRetiredSessionsAfter, func() { 238 h.mutex.Lock() 239 delete(h.handlers, string(id)) 240 h.mutex.Unlock() 241 h.logger.Debugf("Removing connection ID %s after it has been retired.", id) 242 }) 243 } 244 245 func (h *packetHandlerMap) ReplaceWithClosed(id protocol.ConnectionID, handler packetHandler) { 246 h.mutex.Lock() 247 h.handlers[string(id)] = packetHandlerMapEntry{packetHandler: handler} 248 h.mutex.Unlock() 249 h.logger.Debugf("Replacing session for connection ID %s with a closed session.", id) 250 251 time.AfterFunc(h.deleteRetiredSessionsAfter, func() { 252 h.mutex.Lock() 253 handler.shutdown() 254 delete(h.handlers, string(id)) 255 h.mutex.Unlock() 256 h.logger.Debugf("Removing connection ID %s for a closed session after it has been retired.", id) 257 }) 258 } 259 260 func (h *packetHandlerMap) AddResetToken(token protocol.StatelessResetToken, handler packetHandler) { 261 h.mutex.Lock() 262 h.resetTokens[token] = handler 263 h.mutex.Unlock() 264 } 265 266 func (h *packetHandlerMap) RemoveResetToken(token protocol.StatelessResetToken) { 267 h.mutex.Lock() 268 delete(h.resetTokens, token) 269 h.mutex.Unlock() 270 } 271 272 func (h *packetHandlerMap) SetServer(s unknownPacketHandler) { 273 h.mutex.Lock() 274 h.server = s 275 h.mutex.Unlock() 276 } 277 278 func (h *packetHandlerMap) CloseServer() { 279 h.mutex.Lock() 280 if h.server == nil { 281 h.mutex.Unlock() 282 return 283 } 284 h.server = nil 285 var wg sync.WaitGroup 286 for _, entry := range h.handlers { 287 if entry.packetHandler.getPerspective() == protocol.PerspectiveServer { 288 wg.Add(1) 289 go func(handler packetHandler) { 290 // blocks until the CONNECTION_CLOSE has been sent and the run-loop has stopped 291 handler.shutdown() 292 wg.Done() 293 }(entry.packetHandler) 294 } 295 } 296 h.mutex.Unlock() 297 wg.Wait() 298 } 299 300 // Destroy closes the underlying connection and waits until listen() has returned. 301 // It does not close active sessions. 302 func (h *packetHandlerMap) Destroy() error { 303 if err := h.conn.Close(); err != nil { 304 return err 305 } 306 <-h.listening // wait until listening returns 307 return nil 308 } 309 310 func (h *packetHandlerMap) close(e error) error { 311 h.mutex.Lock() 312 if h.closed { 313 h.mutex.Unlock() 314 return nil 315 } 316 317 var wg sync.WaitGroup 318 for _, entry := range h.handlers { 319 wg.Add(1) 320 go func(handler packetHandler) { 321 handler.destroy(e) 322 wg.Done() 323 }(entry.packetHandler) 324 } 325 326 if h.server != nil { 327 h.server.setCloseError(e) 328 } 329 h.closed = true 330 h.mutex.Unlock() 331 wg.Wait() 332 return getMultiplexer().RemoveConn(h.conn) 333 } 334 335 func (h *packetHandlerMap) listen() { 336 defer close(h.listening) 337 for { 338 p, err := h.conn.ReadPacket() 339 if nerr, ok := err.(net.Error); ok && nerr.Temporary() { 340 h.logger.Debugf("Temporary error reading from conn: %w", err) 341 continue 342 } 343 if err != nil { 344 h.close(err) 345 return 346 } 347 h.handlePacket(p) 348 } 349 } 350 351 func (h *packetHandlerMap) handlePacket(p *receivedPacket) { 352 connID, err := wire.ParseConnectionID(p.data, h.connIDLen) 353 if err != nil { 354 h.logger.Debugf("error parsing connection ID on packet from %s: %s", p.remoteAddr, err) 355 if h.tracer != nil { 356 h.tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropHeaderParseError) 357 } 358 p.buffer.MaybeRelease() 359 return 360 } 361 362 h.mutex.Lock() 363 defer h.mutex.Unlock() 364 365 if isStatelessReset := h.maybeHandleStatelessReset(p.data); isStatelessReset { 366 return 367 } 368 369 if entry, ok := h.handlers[string(connID)]; ok { 370 if entry.is0RTTQueue { // only enqueue 0-RTT packets in the 0-RTT queue 371 if wire.Is0RTTPacket(p.data) { 372 entry.packetHandler.handlePacket(p) 373 return 374 } 375 } else { // existing session 376 entry.packetHandler.handlePacket(p) 377 return 378 } 379 } 380 if p.data[0]&0x80 == 0 { 381 go h.maybeSendStatelessReset(p, connID) 382 return 383 } 384 if h.server == nil { // no server set 385 h.logger.Debugf("received a packet with an unexpected connection ID %s", connID) 386 return 387 } 388 if wire.Is0RTTPacket(p.data) { 389 if h.numZeroRTTEntries >= protocol.Max0RTTQueues { 390 return 391 } 392 h.numZeroRTTEntries++ 393 queue := &zeroRTTQueue{queue: make([]*receivedPacket, 0, 8)} 394 h.handlers[string(connID)] = packetHandlerMapEntry{ 395 packetHandler: queue, 396 is0RTTQueue: true, 397 } 398 queue.retireTimer = time.AfterFunc(h.zeroRTTQueueDuration, func() { 399 h.mutex.Lock() 400 defer h.mutex.Unlock() 401 // The entry might have been replaced by an actual session. 402 // Only delete it if it's still a 0-RTT queue. 403 if entry, ok := h.handlers[string(connID)]; ok && entry.is0RTTQueue { 404 delete(h.handlers, string(connID)) 405 h.numZeroRTTEntries-- 406 if h.numZeroRTTEntries < 0 { 407 panic("number of 0-RTT queues < 0") 408 } 409 entry.packetHandler.(*zeroRTTQueue).Clear() 410 if h.logger.Debug() { 411 h.logger.Debugf("Removing 0-RTT queue for %s.", connID) 412 } 413 } 414 }) 415 queue.handlePacket(p) 416 return 417 } 418 h.server.handlePacket(p) 419 } 420 421 func (h *packetHandlerMap) maybeHandleStatelessReset(data []byte) bool { 422 // stateless resets are always short header packets 423 if data[0]&0x80 != 0 { 424 return false 425 } 426 if len(data) < 17 /* type byte + 16 bytes for the reset token */ { 427 return false 428 } 429 430 var token protocol.StatelessResetToken 431 copy(token[:], data[len(data)-16:]) 432 if sess, ok := h.resetTokens[token]; ok { 433 h.logger.Debugf("Received a stateless reset with token %#x. Closing session.", token) 434 go sess.destroy(&StatelessResetError{Token: token}) 435 return true 436 } 437 return false 438 } 439 440 func (h *packetHandlerMap) GetStatelessResetToken(connID protocol.ConnectionID) protocol.StatelessResetToken { 441 var token protocol.StatelessResetToken 442 if !h.statelessResetEnabled { 443 // Return a random stateless reset token. 444 // This token will be sent in the server's transport parameters. 445 // By using a random token, an off-path attacker won't be able to disrupt the connection. 446 rand.Read(token[:]) 447 return token 448 } 449 h.statelessResetMutex.Lock() 450 h.statelessResetHasher.Write(connID.Bytes()) 451 copy(token[:], h.statelessResetHasher.Sum(nil)) 452 h.statelessResetHasher.Reset() 453 h.statelessResetMutex.Unlock() 454 return token 455 } 456 457 func (h *packetHandlerMap) maybeSendStatelessReset(p *receivedPacket, connID protocol.ConnectionID) { 458 defer p.buffer.Release() 459 if !h.statelessResetEnabled { 460 return 461 } 462 // Don't send a stateless reset in response to very small packets. 463 // This includes packets that could be stateless resets. 464 if len(p.data) <= protocol.MinStatelessResetSize { 465 return 466 } 467 token := h.GetStatelessResetToken(connID) 468 h.logger.Debugf("Sending stateless reset to %s (connection ID: %s). Token: %#x", p.remoteAddr, connID, token) 469 data := make([]byte, protocol.MinStatelessResetSize-16, protocol.MinStatelessResetSize) 470 rand.Read(data) 471 data[0] = (data[0] & 0x7f) | 0x40 472 data = append(data, token[:]...) 473 if _, err := h.conn.WritePacket(data, p.remoteAddr, p.info.OOB()); err != nil { 474 h.logger.Debugf("Error sending Stateless Reset: %s", err) 475 } 476 }