github.com/quic-go/quic-go@v0.44.0/internal/wire/header.go (about)

     1  package wire
     2  
     3  import (
     4  	"bytes"
     5  	"encoding/binary"
     6  	"errors"
     7  	"fmt"
     8  	"io"
     9  
    10  	"github.com/quic-go/quic-go/internal/protocol"
    11  	"github.com/quic-go/quic-go/quicvarint"
    12  )
    13  
    14  // ParseConnectionID parses the destination connection ID of a packet.
    15  func ParseConnectionID(data []byte, shortHeaderConnIDLen int) (protocol.ConnectionID, error) {
    16  	if len(data) == 0 {
    17  		return protocol.ConnectionID{}, io.EOF
    18  	}
    19  	if !IsLongHeaderPacket(data[0]) {
    20  		if len(data) < shortHeaderConnIDLen+1 {
    21  			return protocol.ConnectionID{}, io.EOF
    22  		}
    23  		return protocol.ParseConnectionID(data[1 : 1+shortHeaderConnIDLen]), nil
    24  	}
    25  	if len(data) < 6 {
    26  		return protocol.ConnectionID{}, io.EOF
    27  	}
    28  	destConnIDLen := int(data[5])
    29  	if destConnIDLen > protocol.MaxConnIDLen {
    30  		return protocol.ConnectionID{}, protocol.ErrInvalidConnectionIDLen
    31  	}
    32  	if len(data) < 6+destConnIDLen {
    33  		return protocol.ConnectionID{}, io.EOF
    34  	}
    35  	return protocol.ParseConnectionID(data[6 : 6+destConnIDLen]), nil
    36  }
    37  
    38  // ParseArbitraryLenConnectionIDs parses the most general form of a Long Header packet,
    39  // using only the version-independent packet format as described in Section 5.1 of RFC 8999:
    40  // https://datatracker.ietf.org/doc/html/rfc8999#section-5.1.
    41  // This function should only be called on Long Header packets for which we don't support the version.
    42  func ParseArbitraryLenConnectionIDs(data []byte) (bytesParsed int, dest, src protocol.ArbitraryLenConnectionID, _ error) {
    43  	r := bytes.NewReader(data)
    44  	remaining := r.Len()
    45  	src, dest, err := parseArbitraryLenConnectionIDs(r)
    46  	return remaining - r.Len(), src, dest, err
    47  }
    48  
    49  func parseArbitraryLenConnectionIDs(r *bytes.Reader) (dest, src protocol.ArbitraryLenConnectionID, _ error) {
    50  	r.Seek(5, io.SeekStart) // skip first byte and version field
    51  	destConnIDLen, err := r.ReadByte()
    52  	if err != nil {
    53  		return nil, nil, err
    54  	}
    55  	destConnID := make(protocol.ArbitraryLenConnectionID, destConnIDLen)
    56  	if _, err := io.ReadFull(r, destConnID); err != nil {
    57  		if err == io.ErrUnexpectedEOF {
    58  			err = io.EOF
    59  		}
    60  		return nil, nil, err
    61  	}
    62  	srcConnIDLen, err := r.ReadByte()
    63  	if err != nil {
    64  		return nil, nil, err
    65  	}
    66  	srcConnID := make(protocol.ArbitraryLenConnectionID, srcConnIDLen)
    67  	if _, err := io.ReadFull(r, srcConnID); err != nil {
    68  		if err == io.ErrUnexpectedEOF {
    69  			err = io.EOF
    70  		}
    71  		return nil, nil, err
    72  	}
    73  	return destConnID, srcConnID, nil
    74  }
    75  
    76  func IsPotentialQUICPacket(firstByte byte) bool {
    77  	return firstByte&0x40 > 0
    78  }
    79  
    80  // IsLongHeaderPacket says if this is a Long Header packet
    81  func IsLongHeaderPacket(firstByte byte) bool {
    82  	return firstByte&0x80 > 0
    83  }
    84  
    85  // ParseVersion parses the QUIC version.
    86  // It should only be called for Long Header packets (Short Header packets don't contain a version number).
    87  func ParseVersion(data []byte) (protocol.Version, error) {
    88  	if len(data) < 5 {
    89  		return 0, io.EOF
    90  	}
    91  	return protocol.Version(binary.BigEndian.Uint32(data[1:5])), nil
    92  }
    93  
    94  // IsVersionNegotiationPacket says if this is a version negotiation packet
    95  func IsVersionNegotiationPacket(b []byte) bool {
    96  	if len(b) < 5 {
    97  		return false
    98  	}
    99  	return IsLongHeaderPacket(b[0]) && b[1] == 0 && b[2] == 0 && b[3] == 0 && b[4] == 0
   100  }
   101  
   102  // Is0RTTPacket says if this is a 0-RTT packet.
   103  // A packet sent with a version we don't understand can never be a 0-RTT packet.
   104  func Is0RTTPacket(b []byte) bool {
   105  	if len(b) < 5 {
   106  		return false
   107  	}
   108  	if !IsLongHeaderPacket(b[0]) {
   109  		return false
   110  	}
   111  	version := protocol.Version(binary.BigEndian.Uint32(b[1:5]))
   112  	//nolint:exhaustive // We only need to test QUIC versions that we support.
   113  	switch version {
   114  	case protocol.Version1:
   115  		return b[0]>>4&0b11 == 0b01
   116  	case protocol.Version2:
   117  		return b[0]>>4&0b11 == 0b10
   118  	default:
   119  		return false
   120  	}
   121  }
   122  
   123  var ErrUnsupportedVersion = errors.New("unsupported version")
   124  
   125  // The Header is the version independent part of the header
   126  type Header struct {
   127  	typeByte byte
   128  	Type     protocol.PacketType
   129  
   130  	Version          protocol.Version
   131  	SrcConnectionID  protocol.ConnectionID
   132  	DestConnectionID protocol.ConnectionID
   133  
   134  	Length protocol.ByteCount
   135  
   136  	Token []byte
   137  
   138  	parsedLen protocol.ByteCount // how many bytes were read while parsing this header
   139  }
   140  
   141  // ParsePacket parses a long header packet.
   142  // The packet is cut according to the length field.
   143  // If we understand the version, the packet is parsed up unto the packet number.
   144  // Otherwise, only the invariant part of the header is parsed.
   145  func ParsePacket(data []byte) (*Header, []byte, []byte, error) {
   146  	if len(data) == 0 || !IsLongHeaderPacket(data[0]) {
   147  		return nil, nil, nil, errors.New("not a long header packet")
   148  	}
   149  	hdr, err := parseHeader(data)
   150  	if err != nil {
   151  		if errors.Is(err, ErrUnsupportedVersion) {
   152  			return hdr, nil, nil, err
   153  		}
   154  		return nil, nil, nil, err
   155  	}
   156  	if protocol.ByteCount(len(data)) < hdr.ParsedLen()+hdr.Length {
   157  		return nil, nil, nil, fmt.Errorf("packet length (%d bytes) is smaller than the expected length (%d bytes)", len(data)-int(hdr.ParsedLen()), hdr.Length)
   158  	}
   159  	packetLen := int(hdr.ParsedLen() + hdr.Length)
   160  	return hdr, data[:packetLen], data[packetLen:], nil
   161  }
   162  
   163  // ParseHeader parses the header:
   164  // * if we understand the version: up to the packet number
   165  // * if not, only the invariant part of the header
   166  func parseHeader(b []byte) (*Header, error) {
   167  	if len(b) == 0 {
   168  		return nil, io.EOF
   169  	}
   170  	typeByte := b[0]
   171  
   172  	h := &Header{typeByte: typeByte}
   173  	l, err := h.parseLongHeader(b[1:])
   174  	h.parsedLen = protocol.ByteCount(l) + 1
   175  	return h, err
   176  }
   177  
   178  func (h *Header) parseLongHeader(b []byte) (int, error) {
   179  	startLen := len(b)
   180  	if len(b) < 5 {
   181  		return 0, io.EOF
   182  	}
   183  	h.Version = protocol.Version(binary.BigEndian.Uint32(b[:4]))
   184  	if h.Version != 0 && h.typeByte&0x40 == 0 {
   185  		return startLen - len(b), errors.New("not a QUIC packet")
   186  	}
   187  	destConnIDLen := int(b[4])
   188  	if destConnIDLen > protocol.MaxConnIDLen {
   189  		return startLen - len(b), protocol.ErrInvalidConnectionIDLen
   190  	}
   191  	b = b[5:]
   192  	if len(b) < destConnIDLen+1 {
   193  		return startLen - len(b), io.EOF
   194  	}
   195  	h.DestConnectionID = protocol.ParseConnectionID(b[:destConnIDLen])
   196  	srcConnIDLen := int(b[destConnIDLen])
   197  	if srcConnIDLen > protocol.MaxConnIDLen {
   198  		return startLen - len(b), protocol.ErrInvalidConnectionIDLen
   199  	}
   200  	b = b[destConnIDLen+1:]
   201  	if len(b) < srcConnIDLen {
   202  		return startLen - len(b), io.EOF
   203  	}
   204  	h.SrcConnectionID = protocol.ParseConnectionID(b[:srcConnIDLen])
   205  	b = b[srcConnIDLen:]
   206  	if h.Version == 0 { // version negotiation packet
   207  		return startLen - len(b), nil
   208  	}
   209  	// If we don't understand the version, we have no idea how to interpret the rest of the bytes
   210  	if !protocol.IsSupportedVersion(protocol.SupportedVersions, h.Version) {
   211  		return startLen - len(b), ErrUnsupportedVersion
   212  	}
   213  
   214  	if h.Version == protocol.Version2 {
   215  		switch h.typeByte >> 4 & 0b11 {
   216  		case 0b00:
   217  			h.Type = protocol.PacketTypeRetry
   218  		case 0b01:
   219  			h.Type = protocol.PacketTypeInitial
   220  		case 0b10:
   221  			h.Type = protocol.PacketType0RTT
   222  		case 0b11:
   223  			h.Type = protocol.PacketTypeHandshake
   224  		}
   225  	} else {
   226  		switch h.typeByte >> 4 & 0b11 {
   227  		case 0b00:
   228  			h.Type = protocol.PacketTypeInitial
   229  		case 0b01:
   230  			h.Type = protocol.PacketType0RTT
   231  		case 0b10:
   232  			h.Type = protocol.PacketTypeHandshake
   233  		case 0b11:
   234  			h.Type = protocol.PacketTypeRetry
   235  		}
   236  	}
   237  
   238  	if h.Type == protocol.PacketTypeRetry {
   239  		tokenLen := len(b) - 16
   240  		if tokenLen <= 0 {
   241  			return startLen - len(b), io.EOF
   242  		}
   243  		h.Token = make([]byte, tokenLen)
   244  		copy(h.Token, b[:tokenLen])
   245  		return startLen - len(b) + tokenLen + 16, nil
   246  	}
   247  
   248  	if h.Type == protocol.PacketTypeInitial {
   249  		tokenLen, n, err := quicvarint.Parse(b)
   250  		if err != nil {
   251  			return startLen - len(b), err
   252  		}
   253  		b = b[n:]
   254  		if tokenLen > uint64(len(b)) {
   255  			return startLen - len(b), io.EOF
   256  		}
   257  		h.Token = make([]byte, tokenLen)
   258  		copy(h.Token, b[:tokenLen])
   259  		b = b[tokenLen:]
   260  	}
   261  
   262  	pl, n, err := quicvarint.Parse(b)
   263  	if err != nil {
   264  		return 0, err
   265  	}
   266  	h.Length = protocol.ByteCount(pl)
   267  	return startLen - len(b) + n, nil
   268  }
   269  
   270  // ParsedLen returns the number of bytes that were consumed when parsing the header
   271  func (h *Header) ParsedLen() protocol.ByteCount {
   272  	return h.parsedLen
   273  }
   274  
   275  // ParseExtended parses the version dependent part of the header.
   276  // The Reader has to be set such that it points to the first byte of the header.
   277  func (h *Header) ParseExtended(b *bytes.Reader, ver protocol.Version) (*ExtendedHeader, error) {
   278  	extHdr := h.toExtendedHeader()
   279  	reservedBitsValid, err := extHdr.parse(b, ver)
   280  	if err != nil {
   281  		return nil, err
   282  	}
   283  	if !reservedBitsValid {
   284  		return extHdr, ErrInvalidReservedBits
   285  	}
   286  	return extHdr, nil
   287  }
   288  
   289  func (h *Header) toExtendedHeader() *ExtendedHeader {
   290  	return &ExtendedHeader{Header: *h}
   291  }
   292  
   293  // PacketType is the type of the packet, for logging purposes
   294  func (h *Header) PacketType() string {
   295  	return h.Type.String()
   296  }