github.com/tumi8/quic-go@v0.37.4-tum/packet_handler_map.go (about)

     1  package quic
     2  
     3  import (
     4  	"crypto/hmac"
     5  	"crypto/rand"
     6  	"crypto/sha256"
     7  	"errors"
     8  	"hash"
     9  	"io"
    10  	"net"
    11  	"sync"
    12  	"time"
    13  
    14  	"github.com/tumi8/quic-go/noninternal/protocol"
    15  	"github.com/tumi8/quic-go/noninternal/utils"
    16  )
    17  
    18  type connCapabilities struct {
    19  	// This connection has the Don't Fragment (DF) bit set.
    20  	// This means it makes to run DPLPMTUD.
    21  	DF bool
    22  	// GSO (Generic Segmentation Offload) supported
    23  	GSO bool
    24  }
    25  
    26  // rawConn is a connection that allow reading of a receivedPackeh.
    27  type rawConn interface {
    28  	ReadPacket() (receivedPacket, error)
    29  	// The size parameter is used for GSO.
    30  	// If GSO is not support, len(b) must be equal to size.
    31  	WritePacket(b []byte, size uint16, addr net.Addr, oob []byte) (int, error)
    32  	LocalAddr() net.Addr
    33  	SetReadDeadline(time.Time) error
    34  	io.Closer
    35  
    36  	capabilities() connCapabilities
    37  }
    38  
// closePacket is a previously built CONNECTION_CLOSE packet that is handed to
// packetHandlerMap.enqueueClosePacket to be (re)sent to addr.
type closePacket struct {
	// payload is the raw packet to send.
	payload []byte
	// addr is the remote address the packet is sent to.
	addr net.Addr
	// NOTE(review): info presumably describes the local side of the packet that
	// triggered this (re)transmission — confirm against packetInfo.
	info packetInfo
}
    44  
// unknownPacketHandler is a packet sink with a close error.
// NOTE(review): judging by its name, it presumably receives packets whose
// connection ID is not registered in the packetHandlerMap — confirm with callers.
type unknownPacketHandler interface {
	handlePacket(receivedPacket)
	// setCloseError presumably records the error this handler was closed with — confirm.
	setCloseError(error)
}
    49  
// errListenerAlreadySet signals that a listener was requested while one is
// already set. NOTE(review): no use site is visible in this file — confirm
// the exact condition against its callers.
var errListenerAlreadySet = errors.New("listener already set")
    51  
// packetHandlerMap maps connection IDs and stateless reset tokens to the
// packetHandler responsible for them. Its methods synchronize internally:
// mutex guards handlers, resetTokens and closed, and statelessResetMutex
// serializes use of statelessResetHasher.
type packetHandlerMap struct {
	// mutex guards handlers, resetTokens and closed.
	// A single handler may be registered under several connection IDs.
	mutex       sync.Mutex
	handlers    map[protocol.ConnectionID]packetHandler
	resetTokens map[protocol.StatelessResetToken] /* stateless reset token */ packetHandler

	// closed is set by Close; closeChan is closed by Close, which stops logUsage.
	closed    bool
	closeChan chan struct{}

	// enqueueClosePacket queues a CONNECTION_CLOSE packet for (re)transmission.
	// It is invoked by the closed-connection placeholders that
	// ReplaceWithClosed installs for locally closed connections.
	enqueueClosePacket func(closePacket)

	// deleteRetiredConnsAfter is how long retired and closed connection IDs
	// remain in handlers before they are deleted.
	deleteRetiredConnsAfter time.Duration

	// statelessResetHasher is an HMAC-SHA256 keyed with the StatelessResetKey,
	// or nil if no key was configured. A hash.Hash carries state between
	// Write/Sum/Reset, so its use is serialized by statelessResetMutex.
	statelessResetMutex  sync.Mutex
	statelessResetHasher hash.Hash

	logger utils.Logger
}

// Compile-time check that *packetHandlerMap implements packetHandlerManager.
var _ packetHandlerManager = &packetHandlerMap{}
    71  
    72  func newPacketHandlerMap(key *StatelessResetKey, enqueueClosePacket func(closePacket), logger utils.Logger) *packetHandlerMap {
    73  	h := &packetHandlerMap{
    74  		closeChan:               make(chan struct{}),
    75  		handlers:                make(map[protocol.ConnectionID]packetHandler),
    76  		resetTokens:             make(map[protocol.StatelessResetToken]packetHandler),
    77  		deleteRetiredConnsAfter: protocol.RetiredConnectionIDDeleteTimeout,
    78  		enqueueClosePacket:      enqueueClosePacket,
    79  		logger:                  logger,
    80  	}
    81  	if key != nil {
    82  		h.statelessResetHasher = hmac.New(sha256.New, key[:])
    83  	}
    84  	if h.logger.Debug() {
    85  		go h.logUsage()
    86  	}
    87  	return h
    88  }
    89  
    90  func (h *packetHandlerMap) logUsage() {
    91  	ticker := time.NewTicker(2 * time.Second)
    92  	var printedZero bool
    93  	for {
    94  		select {
    95  		case <-h.closeChan:
    96  			return
    97  		case <-ticker.C:
    98  		}
    99  
   100  		h.mutex.Lock()
   101  		numHandlers := len(h.handlers)
   102  		numTokens := len(h.resetTokens)
   103  		h.mutex.Unlock()
   104  		// If the number tracked handlers and tokens is zero, only print it a single time.
   105  		hasZero := numHandlers == 0 && numTokens == 0
   106  		if !hasZero || (hasZero && !printedZero) {
   107  			h.logger.Debugf("Tracking %d connection IDs and %d reset tokens.\n", numHandlers, numTokens)
   108  			printedZero = false
   109  			if hasZero {
   110  				printedZero = true
   111  			}
   112  		}
   113  	}
   114  }
   115  
   116  func (h *packetHandlerMap) Get(id protocol.ConnectionID) (packetHandler, bool) {
   117  	h.mutex.Lock()
   118  	defer h.mutex.Unlock()
   119  
   120  	handler, ok := h.handlers[id]
   121  	return handler, ok
   122  }
   123  
   124  func (h *packetHandlerMap) Add(id protocol.ConnectionID, handler packetHandler) bool /* was added */ {
   125  	h.mutex.Lock()
   126  	defer h.mutex.Unlock()
   127  
   128  	if _, ok := h.handlers[id]; ok {
   129  		h.logger.Debugf("Not adding connection ID %s, as it already exists.", id)
   130  		return false
   131  	}
   132  	h.handlers[id] = handler
   133  	h.logger.Debugf("Adding connection ID %s.", id)
   134  	return true
   135  }
   136  
   137  func (h *packetHandlerMap) AddWithConnID(clientDestConnID, newConnID protocol.ConnectionID, fn func() (packetHandler, bool)) bool {
   138  	h.mutex.Lock()
   139  	defer h.mutex.Unlock()
   140  
   141  	if _, ok := h.handlers[clientDestConnID]; ok {
   142  		h.logger.Debugf("Not adding connection ID %s for a new connection, as it already exists.", clientDestConnID)
   143  		return false
   144  	}
   145  	conn, ok := fn()
   146  	if !ok {
   147  		return false
   148  	}
   149  	h.handlers[clientDestConnID] = conn
   150  	h.handlers[newConnID] = conn
   151  	h.logger.Debugf("Adding connection IDs %s and %s for a new connection.", clientDestConnID, newConnID)
   152  	return true
   153  }
   154  
   155  func (h *packetHandlerMap) Remove(id protocol.ConnectionID) {
   156  	h.mutex.Lock()
   157  	delete(h.handlers, id)
   158  	h.mutex.Unlock()
   159  	h.logger.Debugf("Removing connection ID %s.", id)
   160  }
   161  
   162  func (h *packetHandlerMap) Retire(id protocol.ConnectionID) {
   163  	h.logger.Debugf("Retiring connection ID %s in %s.", id, h.deleteRetiredConnsAfter)
   164  	time.AfterFunc(h.deleteRetiredConnsAfter, func() {
   165  		h.mutex.Lock()
   166  		delete(h.handlers, id)
   167  		h.mutex.Unlock()
   168  		h.logger.Debugf("Removing connection ID %s after it has been retired.", id)
   169  	})
   170  }
   171  
   172  // ReplaceWithClosed is called when a connection is closed.
   173  // Depending on which side closed the connection, we need to:
   174  // * remote close: absorb delayed packets
   175  // * local close: retransmit the CONNECTION_CLOSE packet, in case it was lost
   176  func (h *packetHandlerMap) ReplaceWithClosed(ids []protocol.ConnectionID, pers protocol.Perspective, connClosePacket []byte) {
   177  	var handler packetHandler
   178  	if connClosePacket != nil {
   179  		handler = newClosedLocalConn(
   180  			func(addr net.Addr, info packetInfo) {
   181  				h.enqueueClosePacket(closePacket{payload: connClosePacket, addr: addr, info: info})
   182  			},
   183  			pers,
   184  			h.logger,
   185  		)
   186  	} else {
   187  		handler = newClosedRemoteConn(pers)
   188  	}
   189  
   190  	h.mutex.Lock()
   191  	for _, id := range ids {
   192  		h.handlers[id] = handler
   193  	}
   194  	h.mutex.Unlock()
   195  	h.logger.Debugf("Replacing connection for connection IDs %s with a closed connection.", ids)
   196  
   197  	time.AfterFunc(h.deleteRetiredConnsAfter, func() {
   198  		h.mutex.Lock()
   199  		handler.shutdown()
   200  		for _, id := range ids {
   201  			delete(h.handlers, id)
   202  		}
   203  		h.mutex.Unlock()
   204  		h.logger.Debugf("Removing connection IDs %s for a closed connection after it has been retired.", ids)
   205  	})
   206  }
   207  
   208  func (h *packetHandlerMap) AddResetToken(token protocol.StatelessResetToken, handler packetHandler) {
   209  	h.mutex.Lock()
   210  	h.resetTokens[token] = handler
   211  	h.mutex.Unlock()
   212  }
   213  
   214  func (h *packetHandlerMap) RemoveResetToken(token protocol.StatelessResetToken) {
   215  	h.mutex.Lock()
   216  	delete(h.resetTokens, token)
   217  	h.mutex.Unlock()
   218  }
   219  
   220  func (h *packetHandlerMap) GetByResetToken(token protocol.StatelessResetToken) (packetHandler, bool) {
   221  	h.mutex.Lock()
   222  	defer h.mutex.Unlock()
   223  
   224  	handler, ok := h.resetTokens[token]
   225  	return handler, ok
   226  }
   227  
   228  func (h *packetHandlerMap) CloseServer() {
   229  	h.mutex.Lock()
   230  	var wg sync.WaitGroup
   231  	for _, handler := range h.handlers {
   232  		if handler.getPerspective() == protocol.PerspectiveServer {
   233  			wg.Add(1)
   234  			go func(handler packetHandler) {
   235  				// blocks until the CONNECTION_CLOSE has been sent and the run-loop has stopped
   236  				handler.shutdown()
   237  				wg.Done()
   238  			}(handler)
   239  		}
   240  	}
   241  	h.mutex.Unlock()
   242  	wg.Wait()
   243  }
   244  
   245  func (h *packetHandlerMap) Close(e error) {
   246  	h.mutex.Lock()
   247  
   248  	if h.closed {
   249  		h.mutex.Unlock()
   250  		return
   251  	}
   252  
   253  	close(h.closeChan)
   254  
   255  	var wg sync.WaitGroup
   256  	for _, handler := range h.handlers {
   257  		wg.Add(1)
   258  		go func(handler packetHandler) {
   259  			handler.destroy(e)
   260  			wg.Done()
   261  		}(handler)
   262  	}
   263  	h.closed = true
   264  	h.mutex.Unlock()
   265  	wg.Wait()
   266  }
   267  
   268  func (h *packetHandlerMap) GetStatelessResetToken(connID protocol.ConnectionID) protocol.StatelessResetToken {
   269  	var token protocol.StatelessResetToken
   270  	if h.statelessResetHasher == nil {
   271  		// Return a random stateless reset token.
   272  		// This token will be sent in the server's transport parameters.
   273  		// By using a random token, an off-path attacker won't be able to disrupt the connection.
   274  		rand.Read(token[:])
   275  		return token
   276  	}
   277  	h.statelessResetMutex.Lock()
   278  	h.statelessResetHasher.Write(connID.Bytes())
   279  	copy(token[:], h.statelessResetHasher.Sum(nil))
   280  	h.statelessResetHasher.Reset()
   281  	h.statelessResetMutex.Unlock()
   282  	return token
   283  }