github.com/sagernet/quic-go@v0.43.1-beta.1/ech/packet_handler_map.go (about)

     1  package quic
     2  
     3  import (
     4  	"crypto/hmac"
     5  	"crypto/rand"
     6  	"crypto/sha256"
     7  	"hash"
     8  	"io"
     9  	"net"
    10  	"sync"
    11  	"time"
    12  
    13  	"github.com/sagernet/quic-go/internal/protocol"
    14  	"github.com/sagernet/quic-go/internal/utils"
    15  )
    16  
// connCapabilities reports which optional features (DF, GSO, ECN)
// a connection supports.
type connCapabilities struct {
	// This connection has the Don't Fragment (DF) bit set.
	// This means it makes sense to run DPLPMTUD.
	DF bool
	// GSO (Generic Segmentation Offload) supported
	GSO bool
	// ECN (Explicit Congestion Notifications) supported
	ECN bool
}
    26  
// rawConn is a connection that allows reading of a receivedPacket.
type rawConn interface {
	// ReadPacket returns a receivedPacket, or an error.
	ReadPacket() (receivedPacket, error)
	// WritePacket writes a packet on the wire.
	// gsoSize is the size of a single packet, or 0 to disable GSO.
	// It is invalid to set gsoSize if capabilities.GSO is not set.
	WritePacket(b []byte, addr net.Addr, packetInfoOOB []byte, gsoSize uint16, ecn protocol.ECN) (int, error)
	// LocalAddr returns the connection's local address.
	LocalAddr() net.Addr
	// SetReadDeadline sets a deadline on the connection.
	// NOTE(review): presumably it applies to ReadPacket — confirm against the implementations.
	SetReadDeadline(time.Time) error
	io.Closer

	// capabilities reports the optional features (see connCapabilities)
	// of this connection.
	capabilities() connCapabilities
}
    40  
// closePacket is the value handed to packetHandlerMap.enqueueClosePacket.
type closePacket struct {
	// payload is the packet to send; ReplaceWithClosed stores the
	// connection-close packet it was given here.
	payload []byte
	// addr and info are forwarded unchanged from the callback that
	// ReplaceWithClosed passes to newClosedLocalConn.
	addr net.Addr
	info packetInfo
}
    46  
// packetHandlerMap maps connection IDs and stateless reset tokens to the
// packetHandler responsible for them.
type packetHandlerMap struct {
	// mutex guards handlers, resetTokens and closed.
	mutex       sync.Mutex
	handlers    map[protocol.ConnectionID]packetHandler
	resetTokens map[protocol.StatelessResetToken] /* stateless reset token */ packetHandler

	// closed is set by Close; closeChan is closed at the same time and
	// terminates the logUsage goroutine.
	closed    bool
	closeChan chan struct{}

	// enqueueClosePacket is invoked by the handler that ReplaceWithClosed
	// installs for locally closed connections.
	enqueueClosePacket func(closePacket)

	// deleteRetiredConnsAfter is the delay after which Retire and
	// ReplaceWithClosed delete connection IDs from handlers.
	deleteRetiredConnsAfter time.Duration

	// statelessResetMutex guards statelessResetHasher.
	// statelessResetHasher is nil if no StatelessResetKey was provided.
	statelessResetMutex  sync.Mutex
	statelessResetHasher hash.Hash

	logger utils.Logger
}
    64  
    65  var _ packetHandlerManager = &packetHandlerMap{}
    66  
    67  func newPacketHandlerMap(key *StatelessResetKey, enqueueClosePacket func(closePacket), logger utils.Logger) *packetHandlerMap {
    68  	h := &packetHandlerMap{
    69  		closeChan:               make(chan struct{}),
    70  		handlers:                make(map[protocol.ConnectionID]packetHandler),
    71  		resetTokens:             make(map[protocol.StatelessResetToken]packetHandler),
    72  		deleteRetiredConnsAfter: protocol.RetiredConnectionIDDeleteTimeout,
    73  		enqueueClosePacket:      enqueueClosePacket,
    74  		logger:                  logger,
    75  	}
    76  	if key != nil {
    77  		h.statelessResetHasher = hmac.New(sha256.New, key[:])
    78  	}
    79  	if h.logger.Debug() {
    80  		go h.logUsage()
    81  	}
    82  	return h
    83  }
    84  
    85  func (h *packetHandlerMap) logUsage() {
    86  	ticker := time.NewTicker(2 * time.Second)
    87  	var printedZero bool
    88  	for {
    89  		select {
    90  		case <-h.closeChan:
    91  			return
    92  		case <-ticker.C:
    93  		}
    94  
    95  		h.mutex.Lock()
    96  		numHandlers := len(h.handlers)
    97  		numTokens := len(h.resetTokens)
    98  		h.mutex.Unlock()
    99  		// If the number tracked handlers and tokens is zero, only print it a single time.
   100  		hasZero := numHandlers == 0 && numTokens == 0
   101  		if !hasZero || (hasZero && !printedZero) {
   102  			h.logger.Debugf("Tracking %d connection IDs and %d reset tokens.\n", numHandlers, numTokens)
   103  			printedZero = false
   104  			if hasZero {
   105  				printedZero = true
   106  			}
   107  		}
   108  	}
   109  }
   110  
   111  func (h *packetHandlerMap) Get(id protocol.ConnectionID) (packetHandler, bool) {
   112  	h.mutex.Lock()
   113  	defer h.mutex.Unlock()
   114  
   115  	handler, ok := h.handlers[id]
   116  	return handler, ok
   117  }
   118  
   119  func (h *packetHandlerMap) Add(id protocol.ConnectionID, handler packetHandler) bool /* was added */ {
   120  	h.mutex.Lock()
   121  	defer h.mutex.Unlock()
   122  
   123  	if _, ok := h.handlers[id]; ok {
   124  		h.logger.Debugf("Not adding connection ID %s, as it already exists.", id)
   125  		return false
   126  	}
   127  	h.handlers[id] = handler
   128  	h.logger.Debugf("Adding connection ID %s.", id)
   129  	return true
   130  }
   131  
   132  func (h *packetHandlerMap) AddWithConnID(clientDestConnID, newConnID protocol.ConnectionID, handler packetHandler) bool {
   133  	h.mutex.Lock()
   134  	defer h.mutex.Unlock()
   135  
   136  	if _, ok := h.handlers[clientDestConnID]; ok {
   137  		h.logger.Debugf("Not adding connection ID %s for a new connection, as it already exists.", clientDestConnID)
   138  		return false
   139  	}
   140  	h.handlers[clientDestConnID] = handler
   141  	h.handlers[newConnID] = handler
   142  	h.logger.Debugf("Adding connection IDs %s and %s for a new connection.", clientDestConnID, newConnID)
   143  	return true
   144  }
   145  
   146  func (h *packetHandlerMap) Remove(id protocol.ConnectionID) {
   147  	h.mutex.Lock()
   148  	delete(h.handlers, id)
   149  	h.mutex.Unlock()
   150  	h.logger.Debugf("Removing connection ID %s.", id)
   151  }
   152  
   153  func (h *packetHandlerMap) Retire(id protocol.ConnectionID) {
   154  	h.logger.Debugf("Retiring connection ID %s in %s.", id, h.deleteRetiredConnsAfter)
   155  	time.AfterFunc(h.deleteRetiredConnsAfter, func() {
   156  		h.mutex.Lock()
   157  		delete(h.handlers, id)
   158  		h.mutex.Unlock()
   159  		h.logger.Debugf("Removing connection ID %s after it has been retired.", id)
   160  	})
   161  }
   162  
   163  // ReplaceWithClosed is called when a connection is closed.
   164  // Depending on which side closed the connection, we need to:
   165  // * remote close: absorb delayed packets
   166  // * local close: retransmit the CONNECTION_CLOSE packet, in case it was lost
   167  func (h *packetHandlerMap) ReplaceWithClosed(ids []protocol.ConnectionID, connClosePacket []byte) {
   168  	var handler packetHandler
   169  	if connClosePacket != nil {
   170  		handler = newClosedLocalConn(
   171  			func(addr net.Addr, info packetInfo) {
   172  				h.enqueueClosePacket(closePacket{payload: connClosePacket, addr: addr, info: info})
   173  			},
   174  			h.logger,
   175  		)
   176  	} else {
   177  		handler = newClosedRemoteConn()
   178  	}
   179  
   180  	h.mutex.Lock()
   181  	for _, id := range ids {
   182  		h.handlers[id] = handler
   183  	}
   184  	h.mutex.Unlock()
   185  	h.logger.Debugf("Replacing connection for connection IDs %s with a closed connection.", ids)
   186  
   187  	time.AfterFunc(h.deleteRetiredConnsAfter, func() {
   188  		h.mutex.Lock()
   189  		for _, id := range ids {
   190  			delete(h.handlers, id)
   191  		}
   192  		h.mutex.Unlock()
   193  		h.logger.Debugf("Removing connection IDs %s for a closed connection after it has been retired.", ids)
   194  	})
   195  }
   196  
   197  func (h *packetHandlerMap) AddResetToken(token protocol.StatelessResetToken, handler packetHandler) {
   198  	h.mutex.Lock()
   199  	h.resetTokens[token] = handler
   200  	h.mutex.Unlock()
   201  }
   202  
   203  func (h *packetHandlerMap) RemoveResetToken(token protocol.StatelessResetToken) {
   204  	h.mutex.Lock()
   205  	delete(h.resetTokens, token)
   206  	h.mutex.Unlock()
   207  }
   208  
   209  func (h *packetHandlerMap) GetByResetToken(token protocol.StatelessResetToken) (packetHandler, bool) {
   210  	h.mutex.Lock()
   211  	defer h.mutex.Unlock()
   212  
   213  	handler, ok := h.resetTokens[token]
   214  	return handler, ok
   215  }
   216  
   217  func (h *packetHandlerMap) Close(e error) {
   218  	h.mutex.Lock()
   219  
   220  	if h.closed {
   221  		h.mutex.Unlock()
   222  		return
   223  	}
   224  
   225  	close(h.closeChan)
   226  
   227  	var wg sync.WaitGroup
   228  	for _, handler := range h.handlers {
   229  		wg.Add(1)
   230  		go func(handler packetHandler) {
   231  			handler.destroy(e)
   232  			wg.Done()
   233  		}(handler)
   234  	}
   235  	h.closed = true
   236  	h.mutex.Unlock()
   237  	wg.Wait()
   238  }
   239  
   240  func (h *packetHandlerMap) GetStatelessResetToken(connID protocol.ConnectionID) protocol.StatelessResetToken {
   241  	var token protocol.StatelessResetToken
   242  	if h.statelessResetHasher == nil {
   243  		// Return a random stateless reset token.
   244  		// This token will be sent in the server's transport parameters.
   245  		// By using a random token, an off-path attacker won't be able to disrupt the connection.
   246  		rand.Read(token[:])
   247  		return token
   248  	}
   249  	h.statelessResetMutex.Lock()
   250  	h.statelessResetHasher.Write(connID.Bytes())
   251  	copy(token[:], h.statelessResetHasher.Sum(nil))
   252  	h.statelessResetHasher.Reset()
   253  	h.statelessResetMutex.Unlock()
   254  	return token
   255  }