github.com/sagernet/quic-go@v0.43.1-beta.1/conn_id_generator.go (about)

     1  package quic
     2  
     3  import (
     4  	"fmt"
     5  
     6  	"github.com/sagernet/quic-go/internal/protocol"
     7  	"github.com/sagernet/quic-go/internal/qerr"
     8  	"github.com/sagernet/quic-go/internal/utils"
     9  	"github.com/sagernet/quic-go/internal/wire"
    10  )
    11  
    12  type connIDGenerator struct {
    13  	generator  ConnectionIDGenerator
    14  	highestSeq uint64
    15  
    16  	activeSrcConnIDs        map[uint64]protocol.ConnectionID
    17  	initialClientDestConnID *protocol.ConnectionID // nil for the client
    18  
    19  	addConnectionID        func(protocol.ConnectionID)
    20  	getStatelessResetToken func(protocol.ConnectionID) protocol.StatelessResetToken
    21  	removeConnectionID     func(protocol.ConnectionID)
    22  	retireConnectionID     func(protocol.ConnectionID)
    23  	replaceWithClosed      func([]protocol.ConnectionID, []byte)
    24  	queueControlFrame      func(wire.Frame)
    25  }
    26  
    27  func newConnIDGenerator(
    28  	initialConnectionID protocol.ConnectionID,
    29  	initialClientDestConnID *protocol.ConnectionID, // nil for the client
    30  	addConnectionID func(protocol.ConnectionID),
    31  	getStatelessResetToken func(protocol.ConnectionID) protocol.StatelessResetToken,
    32  	removeConnectionID func(protocol.ConnectionID),
    33  	retireConnectionID func(protocol.ConnectionID),
    34  	replaceWithClosed func([]protocol.ConnectionID, []byte),
    35  	queueControlFrame func(wire.Frame),
    36  	generator ConnectionIDGenerator,
    37  ) *connIDGenerator {
    38  	m := &connIDGenerator{
    39  		generator:              generator,
    40  		activeSrcConnIDs:       make(map[uint64]protocol.ConnectionID),
    41  		addConnectionID:        addConnectionID,
    42  		getStatelessResetToken: getStatelessResetToken,
    43  		removeConnectionID:     removeConnectionID,
    44  		retireConnectionID:     retireConnectionID,
    45  		replaceWithClosed:      replaceWithClosed,
    46  		queueControlFrame:      queueControlFrame,
    47  	}
    48  	m.activeSrcConnIDs[0] = initialConnectionID
    49  	m.initialClientDestConnID = initialClientDestConnID
    50  	return m
    51  }
    52  
    53  func (m *connIDGenerator) SetMaxActiveConnIDs(limit uint64) error {
    54  	if m.generator.ConnectionIDLen() == 0 {
    55  		return nil
    56  	}
    57  	// The active_connection_id_limit transport parameter is the number of
    58  	// connection IDs the peer will store. This limit includes the connection ID
    59  	// used during the handshake, and the one sent in the preferred_address
    60  	// transport parameter.
    61  	// We currently don't send the preferred_address transport parameter,
    62  	// so we can issue (limit - 1) connection IDs.
    63  	for i := uint64(len(m.activeSrcConnIDs)); i < utils.Min(limit, protocol.MaxIssuedConnectionIDs); i++ {
    64  		if err := m.issueNewConnID(); err != nil {
    65  			return err
    66  		}
    67  	}
    68  	return nil
    69  }
    70  
    71  func (m *connIDGenerator) Retire(seq uint64, sentWithDestConnID protocol.ConnectionID) error {
    72  	if seq > m.highestSeq {
    73  		return &qerr.TransportError{
    74  			ErrorCode:    qerr.ProtocolViolation,
    75  			ErrorMessage: fmt.Sprintf("retired connection ID %d (highest issued: %d)", seq, m.highestSeq),
    76  		}
    77  	}
    78  	connID, ok := m.activeSrcConnIDs[seq]
    79  	// We might already have deleted this connection ID, if this is a duplicate frame.
    80  	if !ok {
    81  		return nil
    82  	}
    83  	if connID == sentWithDestConnID {
    84  		return &qerr.TransportError{
    85  			ErrorCode:    qerr.ProtocolViolation,
    86  			ErrorMessage: fmt.Sprintf("retired connection ID %d (%s), which was used as the Destination Connection ID on this packet", seq, connID),
    87  		}
    88  	}
    89  	m.retireConnectionID(connID)
    90  	delete(m.activeSrcConnIDs, seq)
    91  	// Don't issue a replacement for the initial connection ID.
    92  	if seq == 0 {
    93  		return nil
    94  	}
    95  	return m.issueNewConnID()
    96  }
    97  
    98  func (m *connIDGenerator) issueNewConnID() error {
    99  	connID, err := m.generator.GenerateConnectionID()
   100  	if err != nil {
   101  		return err
   102  	}
   103  	m.activeSrcConnIDs[m.highestSeq+1] = connID
   104  	m.addConnectionID(connID)
   105  	m.queueControlFrame(&wire.NewConnectionIDFrame{
   106  		SequenceNumber:      m.highestSeq + 1,
   107  		ConnectionID:        connID,
   108  		StatelessResetToken: m.getStatelessResetToken(connID),
   109  	})
   110  	m.highestSeq++
   111  	return nil
   112  }
   113  
   114  func (m *connIDGenerator) SetHandshakeComplete() {
   115  	if m.initialClientDestConnID != nil {
   116  		m.retireConnectionID(*m.initialClientDestConnID)
   117  		m.initialClientDestConnID = nil
   118  	}
   119  }
   120  
   121  func (m *connIDGenerator) RemoveAll() {
   122  	if m.initialClientDestConnID != nil {
   123  		m.removeConnectionID(*m.initialClientDestConnID)
   124  	}
   125  	for _, connID := range m.activeSrcConnIDs {
   126  		m.removeConnectionID(connID)
   127  	}
   128  }
   129  
   130  func (m *connIDGenerator) ReplaceWithClosed(connClose []byte) {
   131  	connIDs := make([]protocol.ConnectionID, 0, len(m.activeSrcConnIDs)+1)
   132  	if m.initialClientDestConnID != nil {
   133  		connIDs = append(connIDs, *m.initialClientDestConnID)
   134  	}
   135  	for _, connID := range m.activeSrcConnIDs {
   136  		connIDs = append(connIDs, connID)
   137  	}
   138  	m.replaceWithClosed(connIDs, connClose)
   139  }