github.com/daeuniverse/quic-go@v0.0.0-20240413031024-943f218e0810/conn_id_generator.go (about) 1 package quic 2 3 import ( 4 "fmt" 5 6 "github.com/daeuniverse/quic-go/internal/protocol" 7 "github.com/daeuniverse/quic-go/internal/qerr" 8 "github.com/daeuniverse/quic-go/internal/wire" 9 ) 10 11 type connIDGenerator struct { 12 generator ConnectionIDGenerator 13 highestSeq uint64 14 15 activeSrcConnIDs map[uint64]protocol.ConnectionID 16 initialClientDestConnID *protocol.ConnectionID // nil for the client 17 18 addConnectionID func(protocol.ConnectionID) 19 getStatelessResetToken func(protocol.ConnectionID) protocol.StatelessResetToken 20 removeConnectionID func(protocol.ConnectionID) 21 retireConnectionID func(protocol.ConnectionID) 22 replaceWithClosed func([]protocol.ConnectionID, []byte) 23 queueControlFrame func(wire.Frame) 24 } 25 26 func newConnIDGenerator( 27 initialConnectionID protocol.ConnectionID, 28 initialClientDestConnID *protocol.ConnectionID, // nil for the client 29 addConnectionID func(protocol.ConnectionID), 30 getStatelessResetToken func(protocol.ConnectionID) protocol.StatelessResetToken, 31 removeConnectionID func(protocol.ConnectionID), 32 retireConnectionID func(protocol.ConnectionID), 33 replaceWithClosed func([]protocol.ConnectionID, []byte), 34 queueControlFrame func(wire.Frame), 35 generator ConnectionIDGenerator, 36 ) *connIDGenerator { 37 m := &connIDGenerator{ 38 generator: generator, 39 activeSrcConnIDs: make(map[uint64]protocol.ConnectionID), 40 addConnectionID: addConnectionID, 41 getStatelessResetToken: getStatelessResetToken, 42 removeConnectionID: removeConnectionID, 43 retireConnectionID: retireConnectionID, 44 replaceWithClosed: replaceWithClosed, 45 queueControlFrame: queueControlFrame, 46 } 47 m.activeSrcConnIDs[0] = initialConnectionID 48 m.initialClientDestConnID = initialClientDestConnID 49 return m 50 } 51 52 func (m *connIDGenerator) SetMaxActiveConnIDs(limit uint64) error { 53 if m.generator.ConnectionIDLen() == 0 { 54 return nil 55 } 56 // The active_connection_id_limit transport parameter is the number of 57 // connection IDs the peer will store. This limit includes the connection ID 58 // used during the handshake, and the one sent in the preferred_address 59 // transport parameter. 60 // We currently don't send the preferred_address transport parameter, 61 // so we can issue (limit - 1) connection IDs. 62 for i := uint64(len(m.activeSrcConnIDs)); i < min(limit, protocol.MaxIssuedConnectionIDs); i++ { 63 if err := m.issueNewConnID(); err != nil { 64 return err 65 } 66 } 67 return nil 68 } 69 70 func (m *connIDGenerator) Retire(seq uint64, sentWithDestConnID protocol.ConnectionID) error { 71 if seq > m.highestSeq { 72 return &qerr.TransportError{ 73 ErrorCode: qerr.ProtocolViolation, 74 ErrorMessage: fmt.Sprintf("retired connection ID %d (highest issued: %d)", seq, m.highestSeq), 75 } 76 } 77 connID, ok := m.activeSrcConnIDs[seq] 78 // We might already have deleted this connection ID, if this is a duplicate frame. 79 if !ok { 80 return nil 81 } 82 if connID == sentWithDestConnID { 83 return &qerr.TransportError{ 84 ErrorCode: qerr.ProtocolViolation, 85 ErrorMessage: fmt.Sprintf("retired connection ID %d (%s), which was used as the Destination Connection ID on this packet", seq, connID), 86 } 87 } 88 m.retireConnectionID(connID) 89 delete(m.activeSrcConnIDs, seq) 90 // Don't issue a replacement for the initial connection ID. 91 if seq == 0 { 92 return nil 93 } 94 return m.issueNewConnID() 95 } 96 97 func (m *connIDGenerator) issueNewConnID() error { 98 connID, err := m.generator.GenerateConnectionID() 99 if err != nil { 100 return err 101 } 102 m.activeSrcConnIDs[m.highestSeq+1] = connID 103 m.addConnectionID(connID) 104 m.queueControlFrame(&wire.NewConnectionIDFrame{ 105 SequenceNumber: m.highestSeq + 1, 106 ConnectionID: connID, 107 StatelessResetToken: m.getStatelessResetToken(connID), 108 }) 109 m.highestSeq++ 110 return nil 111 } 112 113 func (m *connIDGenerator) SetHandshakeComplete() { 114 if m.initialClientDestConnID != nil { 115 m.retireConnectionID(*m.initialClientDestConnID) 116 m.initialClientDestConnID = nil 117 } 118 } 119 120 func (m *connIDGenerator) RemoveAll() { 121 if m.initialClientDestConnID != nil { 122 m.removeConnectionID(*m.initialClientDestConnID) 123 } 124 for _, connID := range m.activeSrcConnIDs { 125 m.removeConnectionID(connID) 126 } 127 } 128 129 func (m *connIDGenerator) ReplaceWithClosed(connClose []byte) { 130 connIDs := make([]protocol.ConnectionID, 0, len(m.activeSrcConnIDs)+1) 131 if m.initialClientDestConnID != nil { 132 connIDs = append(connIDs, *m.initialClientDestConnID) 133 } 134 for _, connID := range m.activeSrcConnIDs { 135 connIDs = append(connIDs, connID) 136 } 137 m.replaceWithClosed(connIDs, connClose) 138 }