github.com/sagernet/quic-go@v0.43.1-beta.1/ech/conn_id_manager.go

     1  package quic
     2  
     3  import (
     4  	"fmt"
     5  
     6  	"github.com/sagernet/quic-go/internal/protocol"
     7  	"github.com/sagernet/quic-go/internal/qerr"
     8  	"github.com/sagernet/quic-go/internal/utils"
     9  	list "github.com/sagernet/quic-go/internal/utils/linkedlist"
    10  	"github.com/sagernet/quic-go/internal/wire"
    11  )
    12  
// newConnID is a connection ID obtained from the peer (through a
// NEW_CONNECTION_ID frame or the preferred address) that is waiting in
// connIDManager's queue. It carries the sequence number and the stateless reset
// token that belong to the connection ID.
type newConnID struct {
	SequenceNumber      uint64
	ConnectionID        protocol.ConnectionID
	StatelessResetToken protocol.StatelessResetToken
}
    18  
// connIDManager tracks the connection IDs the peer has provided. It keeps the
// connection ID currently in use (initially the initial destination connection
// ID) and a queue of spare ones. It retires connection IDs when the peer asks
// for it (Retire Prior To) and, once the handshake is complete, periodically
// switches to a fresh connection ID.
type connIDManager struct {
	// queue holds spare connection IDs, kept sorted by ascending sequence number.
	queue list.List[newConnID]

	// activeSequenceNumber, activeConnectionID and activeStatelessResetToken
	// describe the connection ID currently in use. highestRetired is the
	// retirement watermark: NEW_CONNECTION_ID frames whose sequence number lies
	// below it are answered with an immediate RETIRE_CONNECTION_ID.
	handshakeComplete         bool
	activeSequenceNumber      uint64
	highestRetired            uint64
	activeConnectionID        protocol.ConnectionID
	activeStatelessResetToken *protocol.StatelessResetToken

	// We change the connection ID after sending on average
	// protocol.PacketsPerConnectionID packets. The actual value is randomized
	// to hide the packet loss rate from on-path observers.
	rand                   utils.Rand
	packetsSinceLastChange uint32
	packetsPerConnectionID uint32

	// Callbacks: the stateless reset token of the active connection ID is
	// registered and unregistered through addStatelessResetToken and
	// removeStatelessResetToken. RETIRE_CONNECTION_ID frames are handed to
	// queueControlFrame for sending.
	addStatelessResetToken    func(protocol.StatelessResetToken)
	removeStatelessResetToken func(protocol.StatelessResetToken)
	queueControlFrame         func(wire.Frame)
}
    39  
    40  func newConnIDManager(
    41  	initialDestConnID protocol.ConnectionID,
    42  	addStatelessResetToken func(protocol.StatelessResetToken),
    43  	removeStatelessResetToken func(protocol.StatelessResetToken),
    44  	queueControlFrame func(wire.Frame),
    45  ) *connIDManager {
    46  	return &connIDManager{
    47  		activeConnectionID:        initialDestConnID,
    48  		addStatelessResetToken:    addStatelessResetToken,
    49  		removeStatelessResetToken: removeStatelessResetToken,
    50  		queueControlFrame:         queueControlFrame,
    51  	}
    52  }
    53  
    54  func (h *connIDManager) AddFromPreferredAddress(connID protocol.ConnectionID, resetToken protocol.StatelessResetToken) error {
    55  	return h.addConnectionID(1, connID, resetToken)
    56  }
    57  
    58  func (h *connIDManager) Add(f *wire.NewConnectionIDFrame) error {
    59  	if err := h.add(f); err != nil {
    60  		return err
    61  	}
    62  	if h.queue.Len() >= protocol.MaxActiveConnectionIDs {
    63  		return &qerr.TransportError{ErrorCode: qerr.ConnectionIDLimitError}
    64  	}
    65  	return nil
    66  }
    67  
    68  func (h *connIDManager) add(f *wire.NewConnectionIDFrame) error {
    69  	// If the NEW_CONNECTION_ID frame is reordered, such that its sequence number is smaller than the currently active
    70  	// connection ID or if it was already retired, send the RETIRE_CONNECTION_ID frame immediately.
    71  	if f.SequenceNumber < h.activeSequenceNumber || f.SequenceNumber < h.highestRetired {
    72  		h.queueControlFrame(&wire.RetireConnectionIDFrame{
    73  			SequenceNumber: f.SequenceNumber,
    74  		})
    75  		return nil
    76  	}
    77  
    78  	// Retire elements in the queue.
    79  	// Doesn't retire the active connection ID.
    80  	if f.RetirePriorTo > h.highestRetired {
    81  		var next *list.Element[newConnID]
    82  		for el := h.queue.Front(); el != nil; el = next {
    83  			if el.Value.SequenceNumber >= f.RetirePriorTo {
    84  				break
    85  			}
    86  			next = el.Next()
    87  			h.queueControlFrame(&wire.RetireConnectionIDFrame{
    88  				SequenceNumber: el.Value.SequenceNumber,
    89  			})
    90  			h.queue.Remove(el)
    91  		}
    92  		h.highestRetired = f.RetirePriorTo
    93  	}
    94  
    95  	if f.SequenceNumber == h.activeSequenceNumber {
    96  		return nil
    97  	}
    98  
    99  	if err := h.addConnectionID(f.SequenceNumber, f.ConnectionID, f.StatelessResetToken); err != nil {
   100  		return err
   101  	}
   102  
   103  	// Retire the active connection ID, if necessary.
   104  	if h.activeSequenceNumber < f.RetirePriorTo {
   105  		// The queue is guaranteed to have at least one element at this point.
   106  		h.updateConnectionID()
   107  	}
   108  	return nil
   109  }
   110  
   111  func (h *connIDManager) addConnectionID(seq uint64, connID protocol.ConnectionID, resetToken protocol.StatelessResetToken) error {
   112  	// insert a new element at the end
   113  	if h.queue.Len() == 0 || h.queue.Back().Value.SequenceNumber < seq {
   114  		h.queue.PushBack(newConnID{
   115  			SequenceNumber:      seq,
   116  			ConnectionID:        connID,
   117  			StatelessResetToken: resetToken,
   118  		})
   119  		return nil
   120  	}
   121  	// insert a new element somewhere in the middle
   122  	for el := h.queue.Front(); el != nil; el = el.Next() {
   123  		if el.Value.SequenceNumber == seq {
   124  			if el.Value.ConnectionID != connID {
   125  				return fmt.Errorf("received conflicting connection IDs for sequence number %d", seq)
   126  			}
   127  			if el.Value.StatelessResetToken != resetToken {
   128  				return fmt.Errorf("received conflicting stateless reset tokens for sequence number %d", seq)
   129  			}
   130  			break
   131  		}
   132  		if el.Value.SequenceNumber > seq {
   133  			h.queue.InsertBefore(newConnID{
   134  				SequenceNumber:      seq,
   135  				ConnectionID:        connID,
   136  				StatelessResetToken: resetToken,
   137  			}, el)
   138  			break
   139  		}
   140  	}
   141  	return nil
   142  }
   143  
   144  func (h *connIDManager) updateConnectionID() {
   145  	h.queueControlFrame(&wire.RetireConnectionIDFrame{
   146  		SequenceNumber: h.activeSequenceNumber,
   147  	})
   148  	h.highestRetired = utils.Max(h.highestRetired, h.activeSequenceNumber)
   149  	if h.activeStatelessResetToken != nil {
   150  		h.removeStatelessResetToken(*h.activeStatelessResetToken)
   151  	}
   152  
   153  	front := h.queue.Remove(h.queue.Front())
   154  	h.activeSequenceNumber = front.SequenceNumber
   155  	h.activeConnectionID = front.ConnectionID
   156  	h.activeStatelessResetToken = &front.StatelessResetToken
   157  	h.packetsSinceLastChange = 0
   158  	h.packetsPerConnectionID = protocol.PacketsPerConnectionID/2 + uint32(h.rand.Int31n(protocol.PacketsPerConnectionID))
   159  	h.addStatelessResetToken(*h.activeStatelessResetToken)
   160  }
   161  
   162  func (h *connIDManager) Close() {
   163  	if h.activeStatelessResetToken != nil {
   164  		h.removeStatelessResetToken(*h.activeStatelessResetToken)
   165  	}
   166  }
   167  
   168  // is called when the server performs a Retry
   169  // and when the server changes the connection ID in the first Initial sent
   170  func (h *connIDManager) ChangeInitialConnID(newConnID protocol.ConnectionID) {
   171  	if h.activeSequenceNumber != 0 {
   172  		panic("expected first connection ID to have sequence number 0")
   173  	}
   174  	h.activeConnectionID = newConnID
   175  }
   176  
   177  // is called when the server provides a stateless reset token in the transport parameters
   178  func (h *connIDManager) SetStatelessResetToken(token protocol.StatelessResetToken) {
   179  	if h.activeSequenceNumber != 0 {
   180  		panic("expected first connection ID to have sequence number 0")
   181  	}
   182  	h.activeStatelessResetToken = &token
   183  	h.addStatelessResetToken(token)
   184  }
   185  
   186  func (h *connIDManager) SentPacket() {
   187  	h.packetsSinceLastChange++
   188  }
   189  
   190  func (h *connIDManager) shouldUpdateConnID() bool {
   191  	if !h.handshakeComplete {
   192  		return false
   193  	}
   194  	// initiate the first change as early as possible (after handshake completion)
   195  	if h.queue.Len() > 0 && h.activeSequenceNumber == 0 {
   196  		return true
   197  	}
   198  	// For later changes, only change if
   199  	// 1. The queue of connection IDs is filled more than 50%.
   200  	// 2. We sent at least PacketsPerConnectionID packets
   201  	return 2*h.queue.Len() >= protocol.MaxActiveConnectionIDs &&
   202  		h.packetsSinceLastChange >= h.packetsPerConnectionID
   203  }
   204  
   205  func (h *connIDManager) Get() protocol.ConnectionID {
   206  	if h.shouldUpdateConnID() {
   207  		h.updateConnectionID()
   208  	}
   209  	return h.activeConnectionID
   210  }
   211  
   212  func (h *connIDManager) SetHandshakeComplete() {
   213  	h.handshakeComplete = true
   214  }