github.com/sagernet/quic-go@v0.43.1-beta.1/ech/conn_id_generator.go (about) 1 package quic 2 3 import ( 4 "fmt" 5 6 "github.com/sagernet/quic-go/internal/protocol" 7 "github.com/sagernet/quic-go/internal/qerr" 8 "github.com/sagernet/quic-go/internal/utils" 9 "github.com/sagernet/quic-go/internal/wire" 10 ) 11 12 type connIDGenerator struct { 13 generator ConnectionIDGenerator 14 highestSeq uint64 15 16 activeSrcConnIDs map[uint64]protocol.ConnectionID 17 initialClientDestConnID *protocol.ConnectionID // nil for the client 18 19 addConnectionID func(protocol.ConnectionID) 20 getStatelessResetToken func(protocol.ConnectionID) protocol.StatelessResetToken 21 removeConnectionID func(protocol.ConnectionID) 22 retireConnectionID func(protocol.ConnectionID) 23 replaceWithClosed func([]protocol.ConnectionID, []byte) 24 queueControlFrame func(wire.Frame) 25 } 26 27 func newConnIDGenerator( 28 initialConnectionID protocol.ConnectionID, 29 initialClientDestConnID *protocol.ConnectionID, // nil for the client 30 addConnectionID func(protocol.ConnectionID), 31 getStatelessResetToken func(protocol.ConnectionID) protocol.StatelessResetToken, 32 removeConnectionID func(protocol.ConnectionID), 33 retireConnectionID func(protocol.ConnectionID), 34 replaceWithClosed func([]protocol.ConnectionID, []byte), 35 queueControlFrame func(wire.Frame), 36 generator ConnectionIDGenerator, 37 ) *connIDGenerator { 38 m := &connIDGenerator{ 39 generator: generator, 40 activeSrcConnIDs: make(map[uint64]protocol.ConnectionID), 41 addConnectionID: addConnectionID, 42 getStatelessResetToken: getStatelessResetToken, 43 removeConnectionID: removeConnectionID, 44 retireConnectionID: retireConnectionID, 45 replaceWithClosed: replaceWithClosed, 46 queueControlFrame: queueControlFrame, 47 } 48 m.activeSrcConnIDs[0] = initialConnectionID 49 m.initialClientDestConnID = initialClientDestConnID 50 return m 51 } 52 53 func (m *connIDGenerator) SetMaxActiveConnIDs(limit uint64) error { 54 if m.generator.ConnectionIDLen() == 0 { 55 return nil 56 } 57 // The active_connection_id_limit transport parameter is the number of 58 // connection IDs the peer will store. This limit includes the connection ID 59 // used during the handshake, and the one sent in the preferred_address 60 // transport parameter. 61 // We currently don't send the preferred_address transport parameter, 62 // so we can issue (limit - 1) connection IDs. 63 for i := uint64(len(m.activeSrcConnIDs)); i < utils.Min(limit, protocol.MaxIssuedConnectionIDs); i++ { 64 if err := m.issueNewConnID(); err != nil { 65 return err 66 } 67 } 68 return nil 69 } 70 71 func (m *connIDGenerator) Retire(seq uint64, sentWithDestConnID protocol.ConnectionID) error { 72 if seq > m.highestSeq { 73 return &qerr.TransportError{ 74 ErrorCode: qerr.ProtocolViolation, 75 ErrorMessage: fmt.Sprintf("retired connection ID %d (highest issued: %d)", seq, m.highestSeq), 76 } 77 } 78 connID, ok := m.activeSrcConnIDs[seq] 79 // We might already have deleted this connection ID, if this is a duplicate frame. 80 if !ok { 81 return nil 82 } 83 if connID == sentWithDestConnID { 84 return &qerr.TransportError{ 85 ErrorCode: qerr.ProtocolViolation, 86 ErrorMessage: fmt.Sprintf("retired connection ID %d (%s), which was used as the Destination Connection ID on this packet", seq, connID), 87 } 88 } 89 m.retireConnectionID(connID) 90 delete(m.activeSrcConnIDs, seq) 91 // Don't issue a replacement for the initial connection ID. 92 if seq == 0 { 93 return nil 94 } 95 return m.issueNewConnID() 96 } 97 98 func (m *connIDGenerator) issueNewConnID() error { 99 connID, err := m.generator.GenerateConnectionID() 100 if err != nil { 101 return err 102 } 103 m.activeSrcConnIDs[m.highestSeq+1] = connID 104 m.addConnectionID(connID) 105 m.queueControlFrame(&wire.NewConnectionIDFrame{ 106 SequenceNumber: m.highestSeq + 1, 107 ConnectionID: connID, 108 StatelessResetToken: m.getStatelessResetToken(connID), 109 }) 110 m.highestSeq++ 111 return nil 112 } 113 114 func (m *connIDGenerator) SetHandshakeComplete() { 115 if m.initialClientDestConnID != nil { 116 m.retireConnectionID(*m.initialClientDestConnID) 117 m.initialClientDestConnID = nil 118 } 119 } 120 121 func (m *connIDGenerator) RemoveAll() { 122 if m.initialClientDestConnID != nil { 123 m.removeConnectionID(*m.initialClientDestConnID) 124 } 125 for _, connID := range m.activeSrcConnIDs { 126 m.removeConnectionID(connID) 127 } 128 } 129 130 func (m *connIDGenerator) ReplaceWithClosed(connClose []byte) { 131 connIDs := make([]protocol.ConnectionID, 0, len(m.activeSrcConnIDs)+1) 132 if m.initialClientDestConnID != nil { 133 connIDs = append(connIDs, *m.initialClientDestConnID) 134 } 135 for _, connID := range m.activeSrcConnIDs { 136 connIDs = append(connIDs, connID) 137 } 138 m.replaceWithClosed(connIDs, connClose) 139 }