github.com/ooni/psiphon/tunnel-core@v0.0.0-20230105123940-fe12a24c96ee/oovendor/quic-go/conn_id_generator.go (about)

     1  package quic
     2  
     3  import (
     4  	"fmt"
     5  
     6  	"github.com/ooni/psiphon/tunnel-core/oovendor/quic-go/internal/protocol"
     7  	"github.com/ooni/psiphon/tunnel-core/oovendor/quic-go/internal/qerr"
     8  	"github.com/ooni/psiphon/tunnel-core/oovendor/quic-go/internal/utils"
     9  	"github.com/ooni/psiphon/tunnel-core/oovendor/quic-go/internal/wire"
    10  )
    11  
// connIDGenerator keeps track of the source connection IDs this endpoint has
// issued to its peer, keyed by sequence number, and issues, retires, removes
// or replaces them through the callbacks supplied at construction time.
type connIDGenerator struct {
	// connIDLen is the length used for every newly generated connection ID
	// (taken from the initial connection ID). highestSeq is the sequence
	// number of the most recently issued connection ID.
	connIDLen  int
	highestSeq uint64

	// activeSrcConnIDs maps sequence number -> connection ID for every ID that
	// has been issued and not yet retired; sequence number 0 is the initial ID.
	// initialClientDestConnID is nil for the client; it is cleared by
	// SetHandshakeComplete after being retired.
	activeSrcConnIDs        map[uint64]protocol.ConnectionID
	initialClientDestConnID protocol.ConnectionID

	// Callbacks provided by the owner of the generator. They are invoked when
	// a connection ID is issued (addConnectionID, getStatelessResetToken,
	// queueControlFrame), retired (retireConnectionID), removed
	// (removeConnectionID) or swapped for a closed handler (replaceWithClosed).
	addConnectionID        func(protocol.ConnectionID)
	getStatelessResetToken func(protocol.ConnectionID) protocol.StatelessResetToken
	removeConnectionID     func(protocol.ConnectionID)
	retireConnectionID     func(protocol.ConnectionID)
	replaceWithClosed      func(protocol.ConnectionID, packetHandler)
	queueControlFrame      func(wire.Frame)

	// version is stored at construction; it is not read by any code in this
	// view (NOTE(review): confirm whether it is still needed).
	version protocol.VersionNumber
}
    28  
    29  func newConnIDGenerator(
    30  	initialConnectionID protocol.ConnectionID,
    31  	initialClientDestConnID protocol.ConnectionID, // nil for the client
    32  	addConnectionID func(protocol.ConnectionID),
    33  	getStatelessResetToken func(protocol.ConnectionID) protocol.StatelessResetToken,
    34  	removeConnectionID func(protocol.ConnectionID),
    35  	retireConnectionID func(protocol.ConnectionID),
    36  	replaceWithClosed func(protocol.ConnectionID, packetHandler),
    37  	queueControlFrame func(wire.Frame),
    38  	version protocol.VersionNumber,
    39  ) *connIDGenerator {
    40  	m := &connIDGenerator{
    41  		connIDLen:              initialConnectionID.Len(),
    42  		activeSrcConnIDs:       make(map[uint64]protocol.ConnectionID),
    43  		addConnectionID:        addConnectionID,
    44  		getStatelessResetToken: getStatelessResetToken,
    45  		removeConnectionID:     removeConnectionID,
    46  		retireConnectionID:     retireConnectionID,
    47  		replaceWithClosed:      replaceWithClosed,
    48  		queueControlFrame:      queueControlFrame,
    49  		version:                version,
    50  	}
    51  	m.activeSrcConnIDs[0] = initialConnectionID
    52  	m.initialClientDestConnID = initialClientDestConnID
    53  	return m
    54  }
    55  
    56  func (m *connIDGenerator) SetMaxActiveConnIDs(limit uint64) error {
    57  	if m.connIDLen == 0 {
    58  		return nil
    59  	}
    60  	// The active_connection_id_limit transport parameter is the number of
    61  	// connection IDs the peer will store. This limit includes the connection ID
    62  	// used during the handshake, and the one sent in the preferred_address
    63  	// transport parameter.
    64  	// We currently don't send the preferred_address transport parameter,
    65  	// so we can issue (limit - 1) connection IDs.
    66  	for i := uint64(len(m.activeSrcConnIDs)); i < utils.MinUint64(limit, protocol.MaxIssuedConnectionIDs); i++ {
    67  		if err := m.issueNewConnID(); err != nil {
    68  			return err
    69  		}
    70  	}
    71  	return nil
    72  }
    73  
    74  func (m *connIDGenerator) Retire(seq uint64, sentWithDestConnID protocol.ConnectionID) error {
    75  	if seq > m.highestSeq {
    76  		return &qerr.TransportError{
    77  			ErrorCode:    qerr.ProtocolViolation,
    78  			ErrorMessage: fmt.Sprintf("retired connection ID %d (highest issued: %d)", seq, m.highestSeq),
    79  		}
    80  	}
    81  	connID, ok := m.activeSrcConnIDs[seq]
    82  	// We might already have deleted this connection ID, if this is a duplicate frame.
    83  	if !ok {
    84  		return nil
    85  	}
    86  	if connID.Equal(sentWithDestConnID) {
    87  		return &qerr.TransportError{
    88  			ErrorCode:    qerr.ProtocolViolation,
    89  			ErrorMessage: fmt.Sprintf("retired connection ID %d (%s), which was used as the Destination Connection ID on this packet", seq, connID),
    90  		}
    91  	}
    92  	m.retireConnectionID(connID)
    93  	delete(m.activeSrcConnIDs, seq)
    94  	// Don't issue a replacement for the initial connection ID.
    95  	if seq == 0 {
    96  		return nil
    97  	}
    98  	return m.issueNewConnID()
    99  }
   100  
   101  func (m *connIDGenerator) issueNewConnID() error {
   102  	connID, err := protocol.GenerateConnectionID(m.connIDLen)
   103  	if err != nil {
   104  		return err
   105  	}
   106  	m.activeSrcConnIDs[m.highestSeq+1] = connID
   107  	m.addConnectionID(connID)
   108  	m.queueControlFrame(&wire.NewConnectionIDFrame{
   109  		SequenceNumber:      m.highestSeq + 1,
   110  		ConnectionID:        connID,
   111  		StatelessResetToken: m.getStatelessResetToken(connID),
   112  	})
   113  	m.highestSeq++
   114  	return nil
   115  }
   116  
   117  func (m *connIDGenerator) SetHandshakeComplete() {
   118  	if m.initialClientDestConnID != nil {
   119  		m.retireConnectionID(m.initialClientDestConnID)
   120  		m.initialClientDestConnID = nil
   121  	}
   122  }
   123  
   124  func (m *connIDGenerator) RemoveAll() {
   125  	if m.initialClientDestConnID != nil {
   126  		m.removeConnectionID(m.initialClientDestConnID)
   127  	}
   128  	for _, connID := range m.activeSrcConnIDs {
   129  		m.removeConnectionID(connID)
   130  	}
   131  }
   132  
   133  func (m *connIDGenerator) ReplaceWithClosed(handler packetHandler) {
   134  	if m.initialClientDestConnID != nil {
   135  		m.replaceWithClosed(m.initialClientDestConnID, handler)
   136  	}
   137  	for _, connID := range m.activeSrcConnIDs {
   138  		m.replaceWithClosed(connID, handler)
   139  	}
   140  }