github.com/apernet/quic-go@v0.43.1-0.20240515053213-5e9e635fd9f0/conn_id_generator.go (about)

     1  package quic
     2  
     3  import (
     4  	"fmt"
     5  
     6  	"github.com/apernet/quic-go/internal/protocol"
     7  	"github.com/apernet/quic-go/internal/qerr"
     8  	"github.com/apernet/quic-go/internal/wire"
     9  )
    10  
// connIDGenerator manages the connection IDs this endpoint issues to its peer
// (its source connection IDs) and reports every change through the callback
// fields below.
type connIDGenerator struct {
	// generator produces every connection ID issued after the initial one.
	generator  ConnectionIDGenerator
	// highestSeq is the highest sequence number issued so far; the initial
	// connection ID is registered under sequence number 0.
	highestSeq uint64

	// activeSrcConnIDs maps sequence number to connection ID for every issued
	// connection ID that has not been retired yet.
	activeSrcConnIDs        map[uint64]protocol.ConnectionID
	initialClientDestConnID *protocol.ConnectionID // nil for the client

	// Callbacks supplied by the owner (NOTE(review): presumably the connection;
	// confirm at the call site of newConnIDGenerator):
	//   - addConnectionID: called with each newly issued connection ID.
	//   - getStatelessResetToken: supplies the token sent in its NEW_CONNECTION_ID frame.
	//   - removeConnectionID: called for every tracked ID by RemoveAll.
	//   - retireConnectionID: called when an ID is retired (Retire, SetHandshakeComplete).
	//   - replaceWithClosed: receives all tracked IDs plus the connClose bytes.
	//   - queueControlFrame: receives the NEW_CONNECTION_ID frames to send.
	addConnectionID        func(protocol.ConnectionID)
	getStatelessResetToken func(protocol.ConnectionID) protocol.StatelessResetToken
	removeConnectionID     func(protocol.ConnectionID)
	retireConnectionID     func(protocol.ConnectionID)
	replaceWithClosed      func([]protocol.ConnectionID, []byte)
	queueControlFrame      func(wire.Frame)
}
    25  
    26  func newConnIDGenerator(
    27  	initialConnectionID protocol.ConnectionID,
    28  	initialClientDestConnID *protocol.ConnectionID, // nil for the client
    29  	addConnectionID func(protocol.ConnectionID),
    30  	getStatelessResetToken func(protocol.ConnectionID) protocol.StatelessResetToken,
    31  	removeConnectionID func(protocol.ConnectionID),
    32  	retireConnectionID func(protocol.ConnectionID),
    33  	replaceWithClosed func([]protocol.ConnectionID, []byte),
    34  	queueControlFrame func(wire.Frame),
    35  	generator ConnectionIDGenerator,
    36  ) *connIDGenerator {
    37  	m := &connIDGenerator{
    38  		generator:              generator,
    39  		activeSrcConnIDs:       make(map[uint64]protocol.ConnectionID),
    40  		addConnectionID:        addConnectionID,
    41  		getStatelessResetToken: getStatelessResetToken,
    42  		removeConnectionID:     removeConnectionID,
    43  		retireConnectionID:     retireConnectionID,
    44  		replaceWithClosed:      replaceWithClosed,
    45  		queueControlFrame:      queueControlFrame,
    46  	}
    47  	m.activeSrcConnIDs[0] = initialConnectionID
    48  	m.initialClientDestConnID = initialClientDestConnID
    49  	return m
    50  }
    51  
    52  func (m *connIDGenerator) SetMaxActiveConnIDs(limit uint64) error {
    53  	if m.generator.ConnectionIDLen() == 0 {
    54  		return nil
    55  	}
    56  	// The active_connection_id_limit transport parameter is the number of
    57  	// connection IDs the peer will store. This limit includes the connection ID
    58  	// used during the handshake, and the one sent in the preferred_address
    59  	// transport parameter.
    60  	// We currently don't send the preferred_address transport parameter,
    61  	// so we can issue (limit - 1) connection IDs.
    62  	for i := uint64(len(m.activeSrcConnIDs)); i < min(limit, protocol.MaxIssuedConnectionIDs); i++ {
    63  		if err := m.issueNewConnID(); err != nil {
    64  			return err
    65  		}
    66  	}
    67  	return nil
    68  }
    69  
    70  func (m *connIDGenerator) Retire(seq uint64, sentWithDestConnID protocol.ConnectionID) error {
    71  	if seq > m.highestSeq {
    72  		return &qerr.TransportError{
    73  			ErrorCode:    qerr.ProtocolViolation,
    74  			ErrorMessage: fmt.Sprintf("retired connection ID %d (highest issued: %d)", seq, m.highestSeq),
    75  		}
    76  	}
    77  	connID, ok := m.activeSrcConnIDs[seq]
    78  	// We might already have deleted this connection ID, if this is a duplicate frame.
    79  	if !ok {
    80  		return nil
    81  	}
    82  	if connID == sentWithDestConnID {
    83  		return &qerr.TransportError{
    84  			ErrorCode:    qerr.ProtocolViolation,
    85  			ErrorMessage: fmt.Sprintf("retired connection ID %d (%s), which was used as the Destination Connection ID on this packet", seq, connID),
    86  		}
    87  	}
    88  	m.retireConnectionID(connID)
    89  	delete(m.activeSrcConnIDs, seq)
    90  	// Don't issue a replacement for the initial connection ID.
    91  	if seq == 0 {
    92  		return nil
    93  	}
    94  	return m.issueNewConnID()
    95  }
    96  
    97  func (m *connIDGenerator) issueNewConnID() error {
    98  	connID, err := m.generator.GenerateConnectionID()
    99  	if err != nil {
   100  		return err
   101  	}
   102  	m.activeSrcConnIDs[m.highestSeq+1] = connID
   103  	m.addConnectionID(connID)
   104  	m.queueControlFrame(&wire.NewConnectionIDFrame{
   105  		SequenceNumber:      m.highestSeq + 1,
   106  		ConnectionID:        connID,
   107  		StatelessResetToken: m.getStatelessResetToken(connID),
   108  	})
   109  	m.highestSeq++
   110  	return nil
   111  }
   112  
   113  func (m *connIDGenerator) SetHandshakeComplete() {
   114  	if m.initialClientDestConnID != nil {
   115  		m.retireConnectionID(*m.initialClientDestConnID)
   116  		m.initialClientDestConnID = nil
   117  	}
   118  }
   119  
   120  func (m *connIDGenerator) RemoveAll() {
   121  	if m.initialClientDestConnID != nil {
   122  		m.removeConnectionID(*m.initialClientDestConnID)
   123  	}
   124  	for _, connID := range m.activeSrcConnIDs {
   125  		m.removeConnectionID(connID)
   126  	}
   127  }
   128  
   129  func (m *connIDGenerator) ReplaceWithClosed(connClose []byte) {
   130  	connIDs := make([]protocol.ConnectionID, 0, len(m.activeSrcConnIDs)+1)
   131  	if m.initialClientDestConnID != nil {
   132  		connIDs = append(connIDs, *m.initialClientDestConnID)
   133  	}
   134  	for _, connID := range m.activeSrcConnIDs {
   135  		connIDs = append(connIDs, connID)
   136  	}
   137  	m.replaceWithClosed(connIDs, connClose)
   138  }