github.com/daeuniverse/quic-go@v0.0.0-20240413031024-943f218e0810/internal/wire/header.go (about) 1 package wire 2 3 import ( 4 "bytes" 5 "encoding/binary" 6 "errors" 7 "fmt" 8 "io" 9 10 "github.com/daeuniverse/quic-go/internal/protocol" 11 "github.com/daeuniverse/quic-go/internal/utils" 12 "github.com/daeuniverse/quic-go/quicvarint" 13 ) 14 15 // ParseConnectionID parses the destination connection ID of a packet. 16 func ParseConnectionID(data []byte, shortHeaderConnIDLen int) (protocol.ConnectionID, error) { 17 if len(data) == 0 { 18 return protocol.ConnectionID{}, io.EOF 19 } 20 if !IsLongHeaderPacket(data[0]) { 21 if len(data) < shortHeaderConnIDLen+1 { 22 return protocol.ConnectionID{}, io.EOF 23 } 24 return protocol.ParseConnectionID(data[1 : 1+shortHeaderConnIDLen]), nil 25 } 26 if len(data) < 6 { 27 return protocol.ConnectionID{}, io.EOF 28 } 29 destConnIDLen := int(data[5]) 30 if destConnIDLen > protocol.MaxConnIDLen { 31 return protocol.ConnectionID{}, protocol.ErrInvalidConnectionIDLen 32 } 33 if len(data) < 6+destConnIDLen { 34 return protocol.ConnectionID{}, io.EOF 35 } 36 return protocol.ParseConnectionID(data[6 : 6+destConnIDLen]), nil 37 } 38 39 // ParseArbitraryLenConnectionIDs parses the most general form of a Long Header packet, 40 // using only the version-independent packet format as described in Section 5.1 of RFC 8999: 41 // https://datatracker.ietf.org/doc/html/rfc8999#section-5.1. 42 // This function should only be called on Long Header packets for which we don't support the version. 43 func ParseArbitraryLenConnectionIDs(data []byte) (bytesParsed int, dest, src protocol.ArbitraryLenConnectionID, _ error) { 44 r := bytes.NewReader(data) 45 remaining := r.Len() 46 src, dest, err := parseArbitraryLenConnectionIDs(r) 47 return remaining - r.Len(), src, dest, err 48 } 49 50 func parseArbitraryLenConnectionIDs(r *bytes.Reader) (dest, src protocol.ArbitraryLenConnectionID, _ error) { 51 r.Seek(5, io.SeekStart) // skip first byte and version field 52 destConnIDLen, err := r.ReadByte() 53 if err != nil { 54 return nil, nil, err 55 } 56 destConnID := make(protocol.ArbitraryLenConnectionID, destConnIDLen) 57 if _, err := io.ReadFull(r, destConnID); err != nil { 58 if err == io.ErrUnexpectedEOF { 59 err = io.EOF 60 } 61 return nil, nil, err 62 } 63 srcConnIDLen, err := r.ReadByte() 64 if err != nil { 65 return nil, nil, err 66 } 67 srcConnID := make(protocol.ArbitraryLenConnectionID, srcConnIDLen) 68 if _, err := io.ReadFull(r, srcConnID); err != nil { 69 if err == io.ErrUnexpectedEOF { 70 err = io.EOF 71 } 72 return nil, nil, err 73 } 74 return destConnID, srcConnID, nil 75 } 76 77 func IsPotentialQUICPacket(firstByte byte) bool { 78 return firstByte&0x40 > 0 79 } 80 81 // IsLongHeaderPacket says if this is a Long Header packet 82 func IsLongHeaderPacket(firstByte byte) bool { 83 return firstByte&0x80 > 0 84 } 85 86 // ParseVersion parses the QUIC version. 87 // It should only be called for Long Header packets (Short Header packets don't contain a version number). 88 func ParseVersion(data []byte) (protocol.Version, error) { 89 if len(data) < 5 { 90 return 0, io.EOF 91 } 92 return protocol.Version(binary.BigEndian.Uint32(data[1:5])), nil 93 } 94 95 // IsVersionNegotiationPacket says if this is a version negotiation packet 96 func IsVersionNegotiationPacket(b []byte) bool { 97 if len(b) < 5 { 98 return false 99 } 100 return IsLongHeaderPacket(b[0]) && b[1] == 0 && b[2] == 0 && b[3] == 0 && b[4] == 0 101 } 102 103 // Is0RTTPacket says if this is a 0-RTT packet. 104 // A packet sent with a version we don't understand can never be a 0-RTT packet. 105 func Is0RTTPacket(b []byte) bool { 106 if len(b) < 5 { 107 return false 108 } 109 if !IsLongHeaderPacket(b[0]) { 110 return false 111 } 112 version := protocol.Version(binary.BigEndian.Uint32(b[1:5])) 113 //nolint:exhaustive // We only need to test QUIC versions that we support. 114 switch version { 115 case protocol.Version1: 116 return b[0]>>4&0b11 == 0b01 117 case protocol.Version2: 118 return b[0]>>4&0b11 == 0b10 119 default: 120 return false 121 } 122 } 123 124 var ErrUnsupportedVersion = errors.New("unsupported version") 125 126 // The Header is the version independent part of the header 127 type Header struct { 128 typeByte byte 129 Type protocol.PacketType 130 131 Version protocol.Version 132 SrcConnectionID protocol.ConnectionID 133 DestConnectionID protocol.ConnectionID 134 135 Length protocol.ByteCount 136 137 Token []byte 138 139 parsedLen protocol.ByteCount // how many bytes were read while parsing this header 140 } 141 142 // ParsePacket parses a packet. 143 // If the packet has a long header, the packet is cut according to the length field. 144 // If we understand the version, the packet is header up unto the packet number. 145 // Otherwise, only the invariant part of the header is parsed. 146 func ParsePacket(data []byte) (*Header, []byte, []byte, error) { 147 if len(data) == 0 || !IsLongHeaderPacket(data[0]) { 148 return nil, nil, nil, errors.New("not a long header packet") 149 } 150 hdr, err := parseHeader(bytes.NewReader(data)) 151 if err != nil { 152 if err == ErrUnsupportedVersion { 153 return hdr, nil, nil, ErrUnsupportedVersion 154 } 155 return nil, nil, nil, err 156 } 157 if protocol.ByteCount(len(data)) < hdr.ParsedLen()+hdr.Length { 158 return nil, nil, nil, fmt.Errorf("packet length (%d bytes) is smaller than the expected length (%d bytes)", len(data)-int(hdr.ParsedLen()), hdr.Length) 159 } 160 packetLen := int(hdr.ParsedLen() + hdr.Length) 161 return hdr, data[:packetLen], data[packetLen:], nil 162 } 163 164 // ParseHeader parses the header. 165 // For short header packets: up to the packet number. 166 // For long header packets: 167 // * if we understand the version: up to the packet number 168 // * if not, only the invariant part of the header 169 func parseHeader(b *bytes.Reader) (*Header, error) { 170 startLen := b.Len() 171 typeByte, err := b.ReadByte() 172 if err != nil { 173 return nil, err 174 } 175 176 h := &Header{typeByte: typeByte} 177 err = h.parseLongHeader(b) 178 h.parsedLen = protocol.ByteCount(startLen - b.Len()) 179 return h, err 180 } 181 182 func (h *Header) parseLongHeader(b *bytes.Reader) error { 183 v, err := utils.BigEndian.ReadUint32(b) 184 if err != nil { 185 return err 186 } 187 h.Version = protocol.Version(v) 188 if h.Version != 0 && h.typeByte&0x40 == 0 { 189 return errors.New("not a QUIC packet") 190 } 191 destConnIDLen, err := b.ReadByte() 192 if err != nil { 193 return err 194 } 195 h.DestConnectionID, err = protocol.ReadConnectionID(b, int(destConnIDLen)) 196 if err != nil { 197 return err 198 } 199 srcConnIDLen, err := b.ReadByte() 200 if err != nil { 201 return err 202 } 203 h.SrcConnectionID, err = protocol.ReadConnectionID(b, int(srcConnIDLen)) 204 if err != nil { 205 return err 206 } 207 if h.Version == 0 { // version negotiation packet 208 return nil 209 } 210 // If we don't understand the version, we have no idea how to interpret the rest of the bytes 211 if !protocol.IsSupportedVersion(protocol.SupportedVersions, h.Version) { 212 return ErrUnsupportedVersion 213 } 214 215 if h.Version == protocol.Version2 { 216 switch h.typeByte >> 4 & 0b11 { 217 case 0b00: 218 h.Type = protocol.PacketTypeRetry 219 case 0b01: 220 h.Type = protocol.PacketTypeInitial 221 case 0b10: 222 h.Type = protocol.PacketType0RTT 223 case 0b11: 224 h.Type = protocol.PacketTypeHandshake 225 } 226 } else { 227 switch h.typeByte >> 4 & 0b11 { 228 case 0b00: 229 h.Type = protocol.PacketTypeInitial 230 case 0b01: 231 h.Type = protocol.PacketType0RTT 232 case 0b10: 233 h.Type = protocol.PacketTypeHandshake 234 case 0b11: 235 h.Type = protocol.PacketTypeRetry 236 } 237 } 238 239 if h.Type == protocol.PacketTypeRetry { 240 tokenLen := b.Len() - 16 241 if tokenLen <= 0 { 242 return io.EOF 243 } 244 h.Token = make([]byte, tokenLen) 245 if _, err := io.ReadFull(b, h.Token); err != nil { 246 return err 247 } 248 _, err := b.Seek(16, io.SeekCurrent) 249 return err 250 } 251 252 if h.Type == protocol.PacketTypeInitial { 253 tokenLen, err := quicvarint.Read(b) 254 if err != nil { 255 return err 256 } 257 if tokenLen > uint64(b.Len()) { 258 return io.EOF 259 } 260 h.Token = make([]byte, tokenLen) 261 if _, err := io.ReadFull(b, h.Token); err != nil { 262 return err 263 } 264 } 265 266 pl, err := quicvarint.Read(b) 267 if err != nil { 268 return err 269 } 270 h.Length = protocol.ByteCount(pl) 271 return nil 272 } 273 274 // ParsedLen returns the number of bytes that were consumed when parsing the header 275 func (h *Header) ParsedLen() protocol.ByteCount { 276 return h.parsedLen 277 } 278 279 // ParseExtended parses the version dependent part of the header. 280 // The Reader has to be set such that it points to the first byte of the header. 281 func (h *Header) ParseExtended(b *bytes.Reader, ver protocol.Version) (*ExtendedHeader, error) { 282 extHdr := h.toExtendedHeader() 283 reservedBitsValid, err := extHdr.parse(b, ver) 284 if err != nil { 285 return nil, err 286 } 287 if !reservedBitsValid { 288 return extHdr, ErrInvalidReservedBits 289 } 290 return extHdr, nil 291 } 292 293 func (h *Header) toExtendedHeader() *ExtendedHeader { 294 return &ExtendedHeader{Header: *h} 295 } 296 297 // PacketType is the type of the packet, for logging purposes 298 func (h *Header) PacketType() string { 299 return h.Type.String() 300 }