github.com/ooni/psiphon/tunnel-core@v0.0.0-20230105123940-fe12a24c96ee/oovendor/quic-go/packet_handler_map.go (about)

     1  package quic
     2  
     3  import (
     4  	"crypto/hmac"
     5  	"crypto/rand"
     6  	"crypto/sha256"
     7  	"errors"
     8  	"fmt"
     9  	"hash"
    10  	"net"
    11  	"os"
    12  	"strconv"
    13  	"strings"
    14  	"sync"
    15  	"time"
    16  
    17  	"github.com/ooni/psiphon/tunnel-core/oovendor/quic-go/internal/protocol"
    18  	"github.com/ooni/psiphon/tunnel-core/oovendor/quic-go/internal/utils"
    19  	"github.com/ooni/psiphon/tunnel-core/oovendor/quic-go/internal/wire"
    20  	"github.com/ooni/psiphon/tunnel-core/oovendor/quic-go/logging"
    21  )
    22  
    23  type zeroRTTQueue struct {
    24  	queue       []*receivedPacket
    25  	retireTimer *time.Timer
    26  }
    27  
    28  var _ packetHandler = &zeroRTTQueue{}
    29  
    30  func (h *zeroRTTQueue) handlePacket(p *receivedPacket) {
    31  	if len(h.queue) < protocol.Max0RTTQueueLen {
    32  		h.queue = append(h.queue, p)
    33  	}
    34  }
    35  func (h *zeroRTTQueue) shutdown()                            {}
    36  func (h *zeroRTTQueue) destroy(error)                        {}
    37  func (h *zeroRTTQueue) getPerspective() protocol.Perspective { return protocol.PerspectiveClient }
    38  func (h *zeroRTTQueue) EnqueueAll(sess packetHandler) {
    39  	for _, p := range h.queue {
    40  		sess.handlePacket(p)
    41  	}
    42  }
    43  
    44  func (h *zeroRTTQueue) Clear() {
    45  	for _, p := range h.queue {
    46  		p.buffer.Release()
    47  	}
    48  }
    49  
    50  type packetHandlerMapEntry struct {
    51  	packetHandler packetHandler
    52  	is0RTTQueue   bool
    53  }
    54  
    55  // The packetHandlerMap stores packetHandlers, identified by connection ID.
    56  // It is used:
    57  // * by the server to store sessions
    58  // * when multiplexing outgoing connections to store clients
    59  type packetHandlerMap struct {
    60  	mutex sync.Mutex
    61  
    62  	conn      connection
    63  	connIDLen int
    64  
    65  	handlers          map[string] /* string(ConnectionID)*/ packetHandlerMapEntry
    66  	resetTokens       map[protocol.StatelessResetToken] /* stateless reset token */ packetHandler
    67  	server            unknownPacketHandler
    68  	numZeroRTTEntries int
    69  
    70  	listening chan struct{} // is closed when listen returns
    71  	closed    bool
    72  
    73  	deleteRetiredSessionsAfter time.Duration
    74  	zeroRTTQueueDuration       time.Duration
    75  
    76  	statelessResetEnabled bool
    77  	statelessResetMutex   sync.Mutex
    78  	statelessResetHasher  hash.Hash
    79  
    80  	tracer logging.Tracer
    81  	logger utils.Logger
    82  }
    83  
    84  var _ packetHandlerManager = &packetHandlerMap{}
    85  
    86  func setReceiveBuffer(c net.PacketConn, logger utils.Logger) error {
    87  	conn, ok := c.(interface{ SetReadBuffer(int) error })
    88  	if !ok {
    89  		return errors.New("connection doesn't allow setting of receive buffer size. Not a *net.UDPConn?")
    90  	}
    91  	size, err := inspectReadBuffer(c)
    92  	if err != nil {
    93  		return fmt.Errorf("failed to determine receive buffer size: %w", err)
    94  	}
    95  	if size >= protocol.DesiredReceiveBufferSize {
    96  		logger.Debugf("Conn has receive buffer of %d kiB (wanted: at least %d kiB)", size/1024, protocol.DesiredReceiveBufferSize/1024)
    97  	}
    98  	if err := conn.SetReadBuffer(protocol.DesiredReceiveBufferSize); err != nil {
    99  		return fmt.Errorf("failed to increase receive buffer size: %w", err)
   100  	}
   101  	newSize, err := inspectReadBuffer(c)
   102  	if err != nil {
   103  		return fmt.Errorf("failed to determine receive buffer size: %w", err)
   104  	}
   105  	if newSize == size {
   106  		return fmt.Errorf("failed to increase receive buffer size (wanted: %d kiB, got %d kiB)", protocol.DesiredReceiveBufferSize/1024, newSize/1024)
   107  	}
   108  	if newSize < protocol.DesiredReceiveBufferSize {
   109  		return fmt.Errorf("failed to sufficiently increase receive buffer size (was: %d kiB, wanted: %d kiB, got: %d kiB)", size/1024, protocol.DesiredReceiveBufferSize/1024, newSize/1024)
   110  	}
   111  	logger.Debugf("Increased receive buffer size to %d kiB", newSize/1024)
   112  	return nil
   113  }
   114  
   115  // only print warnings about the UDP receive buffer size once
   116  var receiveBufferWarningOnce sync.Once
   117  
   118  func newPacketHandlerMap(
   119  	c net.PacketConn,
   120  	connIDLen int,
   121  	statelessResetKey []byte,
   122  	tracer logging.Tracer,
   123  	logger utils.Logger,
   124  ) (packetHandlerManager, error) {
   125  	if err := setReceiveBuffer(c, logger); err != nil {
   126  		if !strings.Contains(err.Error(), "use of closed network connection") {
   127  			receiveBufferWarningOnce.Do(func() {
   128  				if disable, _ := strconv.ParseBool(os.Getenv("QUIC_GO_DISABLE_RECEIVE_BUFFER_WARNING")); disable {
   129  					return
   130  				}
   131  				// [Psiphon]
   132  				// Do not emit alert to stderr (was log.Printf).
   133  				logger.Errorf("%s. See https://github.com/lucas-clemente/quic-go/wiki/UDP-Receive-Buffer-Size for details.", err)
   134  			})
   135  		}
   136  	}
   137  	conn, err := wrapConn(c)
   138  	if err != nil {
   139  		return nil, err
   140  	}
   141  	m := &packetHandlerMap{
   142  		conn:                       conn,
   143  		connIDLen:                  connIDLen,
   144  		listening:                  make(chan struct{}),
   145  		handlers:                   make(map[string]packetHandlerMapEntry),
   146  		resetTokens:                make(map[protocol.StatelessResetToken]packetHandler),
   147  		deleteRetiredSessionsAfter: protocol.RetiredConnectionIDDeleteTimeout,
   148  		zeroRTTQueueDuration:       protocol.Max0RTTQueueingDuration,
   149  		statelessResetEnabled:      len(statelessResetKey) > 0,
   150  		statelessResetHasher:       hmac.New(sha256.New, statelessResetKey),
   151  		tracer:                     tracer,
   152  		logger:                     logger,
   153  	}
   154  	go m.listen()
   155  
   156  	if logger.Debug() {
   157  		go m.logUsage()
   158  	}
   159  	return m, nil
   160  }
   161  
   162  func (h *packetHandlerMap) logUsage() {
   163  	ticker := time.NewTicker(2 * time.Second)
   164  	var printedZero bool
   165  	for {
   166  		select {
   167  		case <-h.listening:
   168  			return
   169  		case <-ticker.C:
   170  		}
   171  
   172  		h.mutex.Lock()
   173  		numHandlers := len(h.handlers)
   174  		numTokens := len(h.resetTokens)
   175  		h.mutex.Unlock()
   176  		// If the number tracked handlers and tokens is zero, only print it a single time.
   177  		hasZero := numHandlers == 0 && numTokens == 0
   178  		if !hasZero || (hasZero && !printedZero) {
   179  			h.logger.Debugf("Tracking %d connection IDs and %d reset tokens.\n", numHandlers, numTokens)
   180  			printedZero = false
   181  			if hasZero {
   182  				printedZero = true
   183  			}
   184  		}
   185  	}
   186  }
   187  
   188  func (h *packetHandlerMap) Add(id protocol.ConnectionID, handler packetHandler) bool /* was added */ {
   189  	h.mutex.Lock()
   190  	defer h.mutex.Unlock()
   191  
   192  	if _, ok := h.handlers[string(id)]; ok {
   193  		h.logger.Debugf("Not adding connection ID %s, as it already exists.", id)
   194  		return false
   195  	}
   196  	h.handlers[string(id)] = packetHandlerMapEntry{packetHandler: handler}
   197  	h.logger.Debugf("Adding connection ID %s.", id)
   198  	return true
   199  }
   200  
   201  func (h *packetHandlerMap) AddWithConnID(clientDestConnID, newConnID protocol.ConnectionID, fn func() packetHandler) bool {
   202  	h.mutex.Lock()
   203  	defer h.mutex.Unlock()
   204  
   205  	var q *zeroRTTQueue
   206  	if entry, ok := h.handlers[string(clientDestConnID)]; ok {
   207  		if !entry.is0RTTQueue {
   208  			h.logger.Debugf("Not adding connection ID %s for a new session, as it already exists.", clientDestConnID)
   209  			return false
   210  		}
   211  		q = entry.packetHandler.(*zeroRTTQueue)
   212  		q.retireTimer.Stop()
   213  		h.numZeroRTTEntries--
   214  		if h.numZeroRTTEntries < 0 {
   215  			panic("number of 0-RTT queues < 0")
   216  		}
   217  	}
   218  	sess := fn()
   219  	if q != nil {
   220  		q.EnqueueAll(sess)
   221  	}
   222  	h.handlers[string(clientDestConnID)] = packetHandlerMapEntry{packetHandler: sess}
   223  	h.handlers[string(newConnID)] = packetHandlerMapEntry{packetHandler: sess}
   224  	h.logger.Debugf("Adding connection IDs %s and %s for a new session.", clientDestConnID, newConnID)
   225  	return true
   226  }
   227  
   228  func (h *packetHandlerMap) Remove(id protocol.ConnectionID) {
   229  	h.mutex.Lock()
   230  	delete(h.handlers, string(id))
   231  	h.mutex.Unlock()
   232  	h.logger.Debugf("Removing connection ID %s.", id)
   233  }
   234  
   235  func (h *packetHandlerMap) Retire(id protocol.ConnectionID) {
   236  	h.logger.Debugf("Retiring connection ID %s in %s.", id, h.deleteRetiredSessionsAfter)
   237  	time.AfterFunc(h.deleteRetiredSessionsAfter, func() {
   238  		h.mutex.Lock()
   239  		delete(h.handlers, string(id))
   240  		h.mutex.Unlock()
   241  		h.logger.Debugf("Removing connection ID %s after it has been retired.", id)
   242  	})
   243  }
   244  
   245  func (h *packetHandlerMap) ReplaceWithClosed(id protocol.ConnectionID, handler packetHandler) {
   246  	h.mutex.Lock()
   247  	h.handlers[string(id)] = packetHandlerMapEntry{packetHandler: handler}
   248  	h.mutex.Unlock()
   249  	h.logger.Debugf("Replacing session for connection ID %s with a closed session.", id)
   250  
   251  	time.AfterFunc(h.deleteRetiredSessionsAfter, func() {
   252  		h.mutex.Lock()
   253  		handler.shutdown()
   254  		delete(h.handlers, string(id))
   255  		h.mutex.Unlock()
   256  		h.logger.Debugf("Removing connection ID %s for a closed session after it has been retired.", id)
   257  	})
   258  }
   259  
   260  func (h *packetHandlerMap) AddResetToken(token protocol.StatelessResetToken, handler packetHandler) {
   261  	h.mutex.Lock()
   262  	h.resetTokens[token] = handler
   263  	h.mutex.Unlock()
   264  }
   265  
   266  func (h *packetHandlerMap) RemoveResetToken(token protocol.StatelessResetToken) {
   267  	h.mutex.Lock()
   268  	delete(h.resetTokens, token)
   269  	h.mutex.Unlock()
   270  }
   271  
   272  func (h *packetHandlerMap) SetServer(s unknownPacketHandler) {
   273  	h.mutex.Lock()
   274  	h.server = s
   275  	h.mutex.Unlock()
   276  }
   277  
   278  func (h *packetHandlerMap) CloseServer() {
   279  	h.mutex.Lock()
   280  	if h.server == nil {
   281  		h.mutex.Unlock()
   282  		return
   283  	}
   284  	h.server = nil
   285  	var wg sync.WaitGroup
   286  	for _, entry := range h.handlers {
   287  		if entry.packetHandler.getPerspective() == protocol.PerspectiveServer {
   288  			wg.Add(1)
   289  			go func(handler packetHandler) {
   290  				// blocks until the CONNECTION_CLOSE has been sent and the run-loop has stopped
   291  				handler.shutdown()
   292  				wg.Done()
   293  			}(entry.packetHandler)
   294  		}
   295  	}
   296  	h.mutex.Unlock()
   297  	wg.Wait()
   298  }
   299  
   300  // Destroy closes the underlying connection and waits until listen() has returned.
   301  // It does not close active sessions.
   302  func (h *packetHandlerMap) Destroy() error {
   303  	if err := h.conn.Close(); err != nil {
   304  		return err
   305  	}
   306  	<-h.listening // wait until listening returns
   307  	return nil
   308  }
   309  
   310  func (h *packetHandlerMap) close(e error) error {
   311  	h.mutex.Lock()
   312  	if h.closed {
   313  		h.mutex.Unlock()
   314  		return nil
   315  	}
   316  
   317  	var wg sync.WaitGroup
   318  	for _, entry := range h.handlers {
   319  		wg.Add(1)
   320  		go func(handler packetHandler) {
   321  			handler.destroy(e)
   322  			wg.Done()
   323  		}(entry.packetHandler)
   324  	}
   325  
   326  	if h.server != nil {
   327  		h.server.setCloseError(e)
   328  	}
   329  	h.closed = true
   330  	h.mutex.Unlock()
   331  	wg.Wait()
   332  	return getMultiplexer().RemoveConn(h.conn)
   333  }
   334  
   335  func (h *packetHandlerMap) listen() {
   336  	defer close(h.listening)
   337  	for {
   338  		p, err := h.conn.ReadPacket()
   339  		if nerr, ok := err.(net.Error); ok && nerr.Temporary() {
   340  			h.logger.Debugf("Temporary error reading from conn: %w", err)
   341  			continue
   342  		}
   343  		if err != nil {
   344  			h.close(err)
   345  			return
   346  		}
   347  		h.handlePacket(p)
   348  	}
   349  }
   350  
   351  func (h *packetHandlerMap) handlePacket(p *receivedPacket) {
   352  	connID, err := wire.ParseConnectionID(p.data, h.connIDLen)
   353  	if err != nil {
   354  		h.logger.Debugf("error parsing connection ID on packet from %s: %s", p.remoteAddr, err)
   355  		if h.tracer != nil {
   356  			h.tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropHeaderParseError)
   357  		}
   358  		p.buffer.MaybeRelease()
   359  		return
   360  	}
   361  
   362  	h.mutex.Lock()
   363  	defer h.mutex.Unlock()
   364  
   365  	if isStatelessReset := h.maybeHandleStatelessReset(p.data); isStatelessReset {
   366  		return
   367  	}
   368  
   369  	if entry, ok := h.handlers[string(connID)]; ok {
   370  		if entry.is0RTTQueue { // only enqueue 0-RTT packets in the 0-RTT queue
   371  			if wire.Is0RTTPacket(p.data) {
   372  				entry.packetHandler.handlePacket(p)
   373  				return
   374  			}
   375  		} else { // existing session
   376  			entry.packetHandler.handlePacket(p)
   377  			return
   378  		}
   379  	}
   380  	if p.data[0]&0x80 == 0 {
   381  		go h.maybeSendStatelessReset(p, connID)
   382  		return
   383  	}
   384  	if h.server == nil { // no server set
   385  		h.logger.Debugf("received a packet with an unexpected connection ID %s", connID)
   386  		return
   387  	}
   388  	if wire.Is0RTTPacket(p.data) {
   389  		if h.numZeroRTTEntries >= protocol.Max0RTTQueues {
   390  			return
   391  		}
   392  		h.numZeroRTTEntries++
   393  		queue := &zeroRTTQueue{queue: make([]*receivedPacket, 0, 8)}
   394  		h.handlers[string(connID)] = packetHandlerMapEntry{
   395  			packetHandler: queue,
   396  			is0RTTQueue:   true,
   397  		}
   398  		queue.retireTimer = time.AfterFunc(h.zeroRTTQueueDuration, func() {
   399  			h.mutex.Lock()
   400  			defer h.mutex.Unlock()
   401  			// The entry might have been replaced by an actual session.
   402  			// Only delete it if it's still a 0-RTT queue.
   403  			if entry, ok := h.handlers[string(connID)]; ok && entry.is0RTTQueue {
   404  				delete(h.handlers, string(connID))
   405  				h.numZeroRTTEntries--
   406  				if h.numZeroRTTEntries < 0 {
   407  					panic("number of 0-RTT queues < 0")
   408  				}
   409  				entry.packetHandler.(*zeroRTTQueue).Clear()
   410  				if h.logger.Debug() {
   411  					h.logger.Debugf("Removing 0-RTT queue for %s.", connID)
   412  				}
   413  			}
   414  		})
   415  		queue.handlePacket(p)
   416  		return
   417  	}
   418  	h.server.handlePacket(p)
   419  }
   420  
   421  func (h *packetHandlerMap) maybeHandleStatelessReset(data []byte) bool {
   422  	// stateless resets are always short header packets
   423  	if data[0]&0x80 != 0 {
   424  		return false
   425  	}
   426  	if len(data) < 17 /* type byte + 16 bytes for the reset token */ {
   427  		return false
   428  	}
   429  
   430  	var token protocol.StatelessResetToken
   431  	copy(token[:], data[len(data)-16:])
   432  	if sess, ok := h.resetTokens[token]; ok {
   433  		h.logger.Debugf("Received a stateless reset with token %#x. Closing session.", token)
   434  		go sess.destroy(&StatelessResetError{Token: token})
   435  		return true
   436  	}
   437  	return false
   438  }
   439  
   440  func (h *packetHandlerMap) GetStatelessResetToken(connID protocol.ConnectionID) protocol.StatelessResetToken {
   441  	var token protocol.StatelessResetToken
   442  	if !h.statelessResetEnabled {
   443  		// Return a random stateless reset token.
   444  		// This token will be sent in the server's transport parameters.
   445  		// By using a random token, an off-path attacker won't be able to disrupt the connection.
   446  		rand.Read(token[:])
   447  		return token
   448  	}
   449  	h.statelessResetMutex.Lock()
   450  	h.statelessResetHasher.Write(connID.Bytes())
   451  	copy(token[:], h.statelessResetHasher.Sum(nil))
   452  	h.statelessResetHasher.Reset()
   453  	h.statelessResetMutex.Unlock()
   454  	return token
   455  }
   456  
   457  func (h *packetHandlerMap) maybeSendStatelessReset(p *receivedPacket, connID protocol.ConnectionID) {
   458  	defer p.buffer.Release()
   459  	if !h.statelessResetEnabled {
   460  		return
   461  	}
   462  	// Don't send a stateless reset in response to very small packets.
   463  	// This includes packets that could be stateless resets.
   464  	if len(p.data) <= protocol.MinStatelessResetSize {
   465  		return
   466  	}
   467  	token := h.GetStatelessResetToken(connID)
   468  	h.logger.Debugf("Sending stateless reset to %s (connection ID: %s). Token: %#x", p.remoteAddr, connID, token)
   469  	data := make([]byte, protocol.MinStatelessResetSize-16, protocol.MinStatelessResetSize)
   470  	rand.Read(data)
   471  	data[0] = (data[0] & 0x7f) | 0x40
   472  	data = append(data, token[:]...)
   473  	if _, err := h.conn.WritePacket(data, p.remoteAddr, p.info.OOB()); err != nil {
   474  		h.logger.Debugf("Error sending Stateless Reset: %s", err)
   475  	}
   476  }