github.com/apernet/quic-go@v0.43.1-0.20240515053213-5e9e635fd9f0/internal/wire/extended_header.go (about) 1 package wire 2 3 import ( 4 "bytes" 5 "encoding/binary" 6 "errors" 7 "fmt" 8 "io" 9 10 "github.com/apernet/quic-go/internal/protocol" 11 "github.com/apernet/quic-go/internal/utils" 12 "github.com/apernet/quic-go/quicvarint" 13 ) 14 15 // ErrInvalidReservedBits is returned when the reserved bits are incorrect. 16 // When this error is returned, parsing continues, and an ExtendedHeader is returned. 17 // This is necessary because we need to decrypt the packet in that case, 18 // in order to avoid a timing side-channel. 19 var ErrInvalidReservedBits = errors.New("invalid reserved bits") 20 21 // ExtendedHeader is the header of a QUIC packet. 22 type ExtendedHeader struct { 23 Header 24 25 typeByte byte 26 27 KeyPhase protocol.KeyPhaseBit 28 29 PacketNumberLen protocol.PacketNumberLen 30 PacketNumber protocol.PacketNumber 31 32 parsedLen protocol.ByteCount 33 } 34 35 func (h *ExtendedHeader) parse(b *bytes.Reader, v protocol.Version) (bool /* reserved bits valid */, error) { 36 startLen := b.Len() 37 // read the (now unencrypted) first byte 38 var err error 39 h.typeByte, err = b.ReadByte() 40 if err != nil { 41 return false, err 42 } 43 if _, err := b.Seek(int64(h.Header.ParsedLen())-1, io.SeekCurrent); err != nil { 44 return false, err 45 } 46 reservedBitsValid, err := h.parseLongHeader(b, v) 47 if err != nil { 48 return false, err 49 } 50 h.parsedLen = protocol.ByteCount(startLen - b.Len()) 51 return reservedBitsValid, err 52 } 53 54 func (h *ExtendedHeader) parseLongHeader(b *bytes.Reader, _ protocol.Version) (bool /* reserved bits valid */, error) { 55 if err := h.readPacketNumber(b); err != nil { 56 return false, err 57 } 58 if h.typeByte&0xc != 0 { 59 return false, nil 60 } 61 return true, nil 62 } 63 64 func (h *ExtendedHeader) readPacketNumber(b *bytes.Reader) error { 65 h.PacketNumberLen = protocol.PacketNumberLen(h.typeByte&0x3) + 1 66 switch h.PacketNumberLen { 67 case protocol.PacketNumberLen1: 68 n, err := b.ReadByte() 69 if err != nil { 70 return err 71 } 72 h.PacketNumber = protocol.PacketNumber(n) 73 case protocol.PacketNumberLen2: 74 n, err := utils.BigEndian.ReadUint16(b) 75 if err != nil { 76 return err 77 } 78 h.PacketNumber = protocol.PacketNumber(n) 79 case protocol.PacketNumberLen3: 80 n, err := utils.BigEndian.ReadUint24(b) 81 if err != nil { 82 return err 83 } 84 h.PacketNumber = protocol.PacketNumber(n) 85 case protocol.PacketNumberLen4: 86 n, err := utils.BigEndian.ReadUint32(b) 87 if err != nil { 88 return err 89 } 90 h.PacketNumber = protocol.PacketNumber(n) 91 default: 92 return fmt.Errorf("invalid packet number length: %d", h.PacketNumberLen) 93 } 94 return nil 95 } 96 97 // Append appends the Header. 98 func (h *ExtendedHeader) Append(b []byte, v protocol.Version) ([]byte, error) { 99 if h.DestConnectionID.Len() > protocol.MaxConnIDLen { 100 return nil, fmt.Errorf("invalid connection ID length: %d bytes", h.DestConnectionID.Len()) 101 } 102 if h.SrcConnectionID.Len() > protocol.MaxConnIDLen { 103 return nil, fmt.Errorf("invalid connection ID length: %d bytes", h.SrcConnectionID.Len()) 104 } 105 106 var packetType uint8 107 if v == protocol.Version2 { 108 //nolint:exhaustive 109 switch h.Type { 110 case protocol.PacketTypeInitial: 111 packetType = 0b01 112 case protocol.PacketType0RTT: 113 packetType = 0b10 114 case protocol.PacketTypeHandshake: 115 packetType = 0b11 116 case protocol.PacketTypeRetry: 117 packetType = 0b00 118 } 119 } else { 120 //nolint:exhaustive 121 switch h.Type { 122 case protocol.PacketTypeInitial: 123 packetType = 0b00 124 case protocol.PacketType0RTT: 125 packetType = 0b01 126 case protocol.PacketTypeHandshake: 127 packetType = 0b10 128 case protocol.PacketTypeRetry: 129 packetType = 0b11 130 } 131 } 132 firstByte := 0xc0 | packetType<<4 133 if h.Type != protocol.PacketTypeRetry { 134 // Retry packets don't have a packet number 135 firstByte |= uint8(h.PacketNumberLen - 1) 136 } 137 138 b = append(b, firstByte) 139 b = append(b, make([]byte, 4)...) 140 binary.BigEndian.PutUint32(b[len(b)-4:], uint32(h.Version)) 141 b = append(b, uint8(h.DestConnectionID.Len())) 142 b = append(b, h.DestConnectionID.Bytes()...) 143 b = append(b, uint8(h.SrcConnectionID.Len())) 144 b = append(b, h.SrcConnectionID.Bytes()...) 145 146 //nolint:exhaustive 147 switch h.Type { 148 case protocol.PacketTypeRetry: 149 b = append(b, h.Token...) 150 return b, nil 151 case protocol.PacketTypeInitial: 152 b = quicvarint.Append(b, uint64(len(h.Token))) 153 b = append(b, h.Token...) 154 } 155 b = quicvarint.AppendWithLen(b, uint64(h.Length), 2) 156 return appendPacketNumber(b, h.PacketNumber, h.PacketNumberLen) 157 } 158 159 // ParsedLen returns the number of bytes that were consumed when parsing the header 160 func (h *ExtendedHeader) ParsedLen() protocol.ByteCount { 161 return h.parsedLen 162 } 163 164 // GetLength determines the length of the Header. 165 func (h *ExtendedHeader) GetLength(_ protocol.Version) protocol.ByteCount { 166 length := 1 /* type byte */ + 4 /* version */ + 1 /* dest conn ID len */ + protocol.ByteCount(h.DestConnectionID.Len()) + 1 /* src conn ID len */ + protocol.ByteCount(h.SrcConnectionID.Len()) + protocol.ByteCount(h.PacketNumberLen) + 2 /* length */ 167 if h.Type == protocol.PacketTypeInitial { 168 length += protocol.ByteCount(quicvarint.Len(uint64(len(h.Token))) + len(h.Token)) 169 } 170 return length 171 } 172 173 // Log logs the Header 174 func (h *ExtendedHeader) Log(logger utils.Logger) { 175 var token string 176 if h.Type == protocol.PacketTypeInitial || h.Type == protocol.PacketTypeRetry { 177 if len(h.Token) == 0 { 178 token = "Token: (empty), " 179 } else { 180 token = fmt.Sprintf("Token: %#x, ", h.Token) 181 } 182 if h.Type == protocol.PacketTypeRetry { 183 logger.Debugf("\tLong Header{Type: %s, DestConnectionID: %s, SrcConnectionID: %s, %sVersion: %s}", h.Type, h.DestConnectionID, h.SrcConnectionID, token, h.Version) 184 return 185 } 186 } 187 logger.Debugf("\tLong Header{Type: %s, DestConnectionID: %s, SrcConnectionID: %s, %sPacketNumber: %d, PacketNumberLen: %d, Length: %d, Version: %s}", h.Type, h.DestConnectionID, h.SrcConnectionID, token, h.PacketNumber, h.PacketNumberLen, h.Length, h.Version) 188 } 189 190 func appendPacketNumber(b []byte, pn protocol.PacketNumber, pnLen protocol.PacketNumberLen) ([]byte, error) { 191 switch pnLen { 192 case protocol.PacketNumberLen1: 193 b = append(b, uint8(pn)) 194 case protocol.PacketNumberLen2: 195 buf := make([]byte, 2) 196 binary.BigEndian.PutUint16(buf, uint16(pn)) 197 b = append(b, buf...) 198 case protocol.PacketNumberLen3: 199 buf := make([]byte, 4) 200 binary.BigEndian.PutUint32(buf, uint32(pn)) 201 b = append(b, buf[1:]...) 202 case protocol.PacketNumberLen4: 203 buf := make([]byte, 4) 204 binary.BigEndian.PutUint32(buf, uint32(pn)) 205 b = append(b, buf...) 206 default: 207 return nil, fmt.Errorf("invalid packet number length: %d", pnLen) 208 } 209 return b, nil 210 }