github.com/danielpfeifer02/quic-go-prio-packs@v0.41.0-28/conn_id_generator.go (about)

     1  package quic
     2  
     3  import (
     4  	"fmt"
     5  
     6  	"github.com/danielpfeifer02/quic-go-prio-packs/internal/protocol"
     7  	"github.com/danielpfeifer02/quic-go-prio-packs/internal/qerr"
     8  	"github.com/danielpfeifer02/quic-go-prio-packs/internal/wire"
     9  )
    10  
// connIDGenerator tracks the connection IDs this endpoint has issued to its
// peer (which the peer then uses as the Destination Connection ID of packets
// it sends). It issues new IDs by queueing NEW_CONNECTION_ID frames, handles
// their retirement, and reports every change through the callback fields below.
type connIDGenerator struct {
	// generator produces new connection IDs. highestSeq is the sequence number
	// of the most recently issued connection ID; it stays 0 until
	// issueNewConnID runs for the first time.
	generator  ConnectionIDGenerator
	highestSeq uint64

	// activeSrcConnIDs holds the issued connection IDs that have not been
	// retired yet, keyed by sequence number; sequence number 0 is the initial
	// connection ID. initialClientDestConnID is retired and cleared by
	// SetHandshakeComplete.
	activeSrcConnIDs        map[uint64]protocol.ConnectionID
	initialClientDestConnID *protocol.ConnectionID // nil for the client

	// Callbacks supplied by the owner. addConnectionID is called with every
	// newly issued connection ID; getStatelessResetToken supplies the token
	// sent along with it in the NEW_CONNECTION_ID frame; removeConnectionID is
	// called for each tracked ID by RemoveAll; retireConnectionID is called
	// when an ID is retired (Retire, SetHandshakeComplete); replaceWithClosed
	// receives all tracked IDs together with the payload passed to
	// ReplaceWithClosed; queueControlFrame queues the NEW_CONNECTION_ID frames.
	addConnectionID        func(protocol.ConnectionID)
	getStatelessResetToken func(protocol.ConnectionID) protocol.StatelessResetToken
	removeConnectionID     func(protocol.ConnectionID)
	retireConnectionID     func(protocol.ConnectionID)
	replaceWithClosed      func([]protocol.ConnectionID, []byte)
	queueControlFrame      func(wire.Frame)
}
    25  
    26  func newConnIDGenerator(
    27  	initialConnectionID protocol.ConnectionID,
    28  	initialClientDestConnID *protocol.ConnectionID, // nil for the client
    29  	addConnectionID func(protocol.ConnectionID),
    30  	getStatelessResetToken func(protocol.ConnectionID) protocol.StatelessResetToken,
    31  	removeConnectionID func(protocol.ConnectionID),
    32  	retireConnectionID func(protocol.ConnectionID),
    33  	replaceWithClosed func([]protocol.ConnectionID, []byte),
    34  	queueControlFrame func(wire.Frame),
    35  	generator ConnectionIDGenerator,
    36  ) *connIDGenerator {
    37  	m := &connIDGenerator{
    38  		generator:              generator,
    39  		activeSrcConnIDs:       make(map[uint64]protocol.ConnectionID),
    40  		addConnectionID:        addConnectionID,
    41  		getStatelessResetToken: getStatelessResetToken,
    42  		removeConnectionID:     removeConnectionID,
    43  		retireConnectionID:     retireConnectionID,
    44  		replaceWithClosed:      replaceWithClosed,
    45  		queueControlFrame:      queueControlFrame,
    46  	}
    47  	m.activeSrcConnIDs[0] = initialConnectionID
    48  	m.initialClientDestConnID = initialClientDestConnID
    49  	return m
    50  }
    51  
    52  func (m *connIDGenerator) SetMaxActiveConnIDs(limit uint64) error {
    53  	if m.generator.ConnectionIDLen() == 0 {
    54  		return nil
    55  	}
    56  
    57  	// PRIO_PACKS_TAG
    58  	if _, ok := m.generator.(*protocol.PriorityConnectionIDGenerator); ok {
    59  		numberOfPriorities := m.generator.(*protocol.PriorityConnectionIDGenerator).NumberOfPriorities
    60  		// < 	ensures that for each priority we have at least one connection ID
    61  		// != 	would ensure that there is *exactly* one connection ID for each priority
    62  		if limit < uint64(numberOfPriorities) {
    63  			fmt.Println("WARNING: active_connection_id_limit is smaller than the number of priorities. Set to the number of priorities.")
    64  			limit = uint64(numberOfPriorities)
    65  		}
    66  		if protocol.MaxIssuedConnectionIDs < uint64(numberOfPriorities) {
    67  			panic("MaxIssuedConnectionIDs is smaller than the number of priorities. Choose a smaller number of priorities or increase MaxIssuedConnectionIDs.")
    68  		}
    69  	}
    70  
    71  	// The active_connection_id_limit transport parameter is the number of
    72  	// connection IDs the peer will store. This limit includes the connection ID
    73  	// used during the handshake, and the one sent in the preferred_address
    74  	// transport parameter.
    75  	// We currently don't send the preferred_address transport parameter,
    76  	// so we can issue (limit - 1) connection IDs.
    77  	for i := uint64(len(m.activeSrcConnIDs)); i < min(limit, protocol.MaxIssuedConnectionIDs); i++ {
    78  		if err := m.issueNewConnID(); err != nil {
    79  			return err
    80  		}
    81  	}
    82  	return nil
    83  }
    84  
    85  func (m *connIDGenerator) Retire(seq uint64, sentWithDestConnID protocol.ConnectionID) error {
    86  	if seq > m.highestSeq {
    87  		return &qerr.TransportError{
    88  			ErrorCode:    qerr.ProtocolViolation,
    89  			ErrorMessage: fmt.Sprintf("retired connection ID %d (highest issued: %d)", seq, m.highestSeq),
    90  		}
    91  	}
    92  	connID, ok := m.activeSrcConnIDs[seq]
    93  	// We might already have deleted this connection ID, if this is a duplicate frame.
    94  	if !ok {
    95  		return nil
    96  	}
    97  	if connID == sentWithDestConnID {
    98  		return &qerr.TransportError{
    99  			ErrorCode:    qerr.ProtocolViolation,
   100  			ErrorMessage: fmt.Sprintf("retired connection ID %d (%s), which was used as the Destination Connection ID on this packet", seq, connID),
   101  		}
   102  	}
   103  	m.retireConnectionID(connID)
   104  	delete(m.activeSrcConnIDs, seq)
   105  	// Don't issue a replacement for the initial connection ID.
   106  	if seq == 0 {
   107  		return nil
   108  	}
   109  
   110  	if _, ok := m.generator.(*protocol.PriorityConnectionIDGenerator); ok {
   111  		// PRIO_PACKS_TAG
   112  		// if the retired connection ID had the same priority as the next one to be issued
   113  		// we need to set the next priority to the one of the retired connection ID
   114  		prio := connID.Bytes()[0]
   115  		m.generator.(*protocol.PriorityConnectionIDGenerator).NextPriority = int8(prio)
   116  		m.generator.(*protocol.PriorityConnectionIDGenerator).NextPriorityValid = true
   117  	}
   118  
   119  	return m.issueNewConnID()
   120  }
   121  
   122  func (m *connIDGenerator) issueNewConnID() error {
   123  	connID, err := m.generator.GenerateConnectionID()
   124  	if err != nil {
   125  		return err
   126  	}
   127  	m.activeSrcConnIDs[m.highestSeq+1] = connID
   128  	m.addConnectionID(connID)
   129  	m.queueControlFrame(&wire.NewConnectionIDFrame{
   130  		SequenceNumber:      m.highestSeq + 1,
   131  		ConnectionID:        connID,
   132  		StatelessResetToken: m.getStatelessResetToken(connID),
   133  	})
   134  	m.highestSeq++
   135  	return nil
   136  }
   137  
   138  func (m *connIDGenerator) SetHandshakeComplete() {
   139  	if m.initialClientDestConnID != nil {
   140  		m.retireConnectionID(*m.initialClientDestConnID)
   141  		m.initialClientDestConnID = nil
   142  	}
   143  }
   144  
   145  func (m *connIDGenerator) RemoveAll() {
   146  	if m.initialClientDestConnID != nil {
   147  		m.removeConnectionID(*m.initialClientDestConnID)
   148  	}
   149  	for _, connID := range m.activeSrcConnIDs {
   150  		m.removeConnectionID(connID)
   151  	}
   152  }
   153  
   154  func (m *connIDGenerator) ReplaceWithClosed(connClose []byte) {
   155  	connIDs := make([]protocol.ConnectionID, 0, len(m.activeSrcConnIDs)+1)
   156  	if m.initialClientDestConnID != nil {
   157  		connIDs = append(connIDs, *m.initialClientDestConnID)
   158  	}
   159  	for _, connID := range m.activeSrcConnIDs {
   160  		connIDs = append(connIDs, connID)
   161  	}
   162  	m.replaceWithClosed(connIDs, connClose)
   163  }