github.com/sagernet/quic-go@v0.43.1-beta.1/internal/wire/header.go (about)

     1  package wire
     2  
     3  import (
     4  	"bytes"
     5  	"encoding/binary"
     6  	"errors"
     7  	"fmt"
     8  	"io"
     9  
    10  	"github.com/sagernet/quic-go/internal/protocol"
    11  	"github.com/sagernet/quic-go/internal/utils"
    12  	"github.com/sagernet/quic-go/quicvarint"
    13  )
    14  
    15  // ParseConnectionID parses the destination connection ID of a packet.
    16  func ParseConnectionID(data []byte, shortHeaderConnIDLen int) (protocol.ConnectionID, error) {
    17  	if len(data) == 0 {
    18  		return protocol.ConnectionID{}, io.EOF
    19  	}
    20  	if !IsLongHeaderPacket(data[0]) {
    21  		if len(data) < shortHeaderConnIDLen+1 {
    22  			return protocol.ConnectionID{}, io.EOF
    23  		}
    24  		return protocol.ParseConnectionID(data[1 : 1+shortHeaderConnIDLen]), nil
    25  	}
    26  	if len(data) < 6 {
    27  		return protocol.ConnectionID{}, io.EOF
    28  	}
    29  	destConnIDLen := int(data[5])
    30  	if destConnIDLen > protocol.MaxConnIDLen {
    31  		return protocol.ConnectionID{}, protocol.ErrInvalidConnectionIDLen
    32  	}
    33  	if len(data) < 6+destConnIDLen {
    34  		return protocol.ConnectionID{}, io.EOF
    35  	}
    36  	return protocol.ParseConnectionID(data[6 : 6+destConnIDLen]), nil
    37  }
    38  
    39  // ParseArbitraryLenConnectionIDs parses the most general form of a Long Header packet,
    40  // using only the version-independent packet format as described in Section 5.1 of RFC 8999:
    41  // https://datatracker.ietf.org/doc/html/rfc8999#section-5.1.
    42  // This function should only be called on Long Header packets for which we don't support the version.
    43  func ParseArbitraryLenConnectionIDs(data []byte) (bytesParsed int, dest, src protocol.ArbitraryLenConnectionID, _ error) {
    44  	r := bytes.NewReader(data)
    45  	remaining := r.Len()
    46  	src, dest, err := parseArbitraryLenConnectionIDs(r)
    47  	return remaining - r.Len(), src, dest, err
    48  }
    49  
    50  func parseArbitraryLenConnectionIDs(r *bytes.Reader) (dest, src protocol.ArbitraryLenConnectionID, _ error) {
    51  	r.Seek(5, io.SeekStart) // skip first byte and version field
    52  	destConnIDLen, err := r.ReadByte()
    53  	if err != nil {
    54  		return nil, nil, err
    55  	}
    56  	destConnID := make(protocol.ArbitraryLenConnectionID, destConnIDLen)
    57  	if _, err := io.ReadFull(r, destConnID); err != nil {
    58  		if err == io.ErrUnexpectedEOF {
    59  			err = io.EOF
    60  		}
    61  		return nil, nil, err
    62  	}
    63  	srcConnIDLen, err := r.ReadByte()
    64  	if err != nil {
    65  		return nil, nil, err
    66  	}
    67  	srcConnID := make(protocol.ArbitraryLenConnectionID, srcConnIDLen)
    68  	if _, err := io.ReadFull(r, srcConnID); err != nil {
    69  		if err == io.ErrUnexpectedEOF {
    70  			err = io.EOF
    71  		}
    72  		return nil, nil, err
    73  	}
    74  	return destConnID, srcConnID, nil
    75  }
    76  
    77  func IsPotentialQUICPacket(firstByte byte) bool {
    78  	return firstByte&0x40 > 0
    79  }
    80  
    81  // IsLongHeaderPacket says if this is a Long Header packet
    82  func IsLongHeaderPacket(firstByte byte) bool {
    83  	return firstByte&0x80 > 0
    84  }
    85  
    86  // ParseVersion parses the QUIC version.
    87  // It should only be called for Long Header packets (Short Header packets don't contain a version number).
    88  func ParseVersion(data []byte) (protocol.Version, error) {
    89  	if len(data) < 5 {
    90  		return 0, io.EOF
    91  	}
    92  	return protocol.Version(binary.BigEndian.Uint32(data[1:5])), nil
    93  }
    94  
    95  // IsVersionNegotiationPacket says if this is a version negotiation packet
    96  func IsVersionNegotiationPacket(b []byte) bool {
    97  	if len(b) < 5 {
    98  		return false
    99  	}
   100  	return IsLongHeaderPacket(b[0]) && b[1] == 0 && b[2] == 0 && b[3] == 0 && b[4] == 0
   101  }
   102  
   103  // Is0RTTPacket says if this is a 0-RTT packet.
   104  // A packet sent with a version we don't understand can never be a 0-RTT packet.
   105  func Is0RTTPacket(b []byte) bool {
   106  	if len(b) < 5 {
   107  		return false
   108  	}
   109  	if !IsLongHeaderPacket(b[0]) {
   110  		return false
   111  	}
   112  	version := protocol.Version(binary.BigEndian.Uint32(b[1:5]))
   113  	//nolint:exhaustive // We only need to test QUIC versions that we support.
   114  	switch version {
   115  	case protocol.Version1:
   116  		return b[0]>>4&0b11 == 0b01
   117  	case protocol.Version2:
   118  		return b[0]>>4&0b11 == 0b10
   119  	default:
   120  		return false
   121  	}
   122  }
   123  
   124  var ErrUnsupportedVersion = errors.New("unsupported version")
   125  
   126  // The Header is the version independent part of the header
   127  type Header struct {
   128  	typeByte byte
   129  	Type     protocol.PacketType
   130  
   131  	Version          protocol.Version
   132  	SrcConnectionID  protocol.ConnectionID
   133  	DestConnectionID protocol.ConnectionID
   134  
   135  	Length protocol.ByteCount
   136  
   137  	Token []byte
   138  
   139  	parsedLen protocol.ByteCount // how many bytes were read while parsing this header
   140  }
   141  
   142  // ParsePacket parses a packet.
   143  // If the packet has a long header, the packet is cut according to the length field.
   144  // If we understand the version, the packet is header up unto the packet number.
   145  // Otherwise, only the invariant part of the header is parsed.
   146  func ParsePacket(data []byte) (*Header, []byte, []byte, error) {
   147  	if len(data) == 0 || !IsLongHeaderPacket(data[0]) {
   148  		return nil, nil, nil, errors.New("not a long header packet")
   149  	}
   150  	hdr, err := parseHeader(bytes.NewReader(data))
   151  	if err != nil {
   152  		if err == ErrUnsupportedVersion {
   153  			return hdr, nil, nil, ErrUnsupportedVersion
   154  		}
   155  		return nil, nil, nil, err
   156  	}
   157  	if protocol.ByteCount(len(data)) < hdr.ParsedLen()+hdr.Length {
   158  		return nil, nil, nil, fmt.Errorf("packet length (%d bytes) is smaller than the expected length (%d bytes)", len(data)-int(hdr.ParsedLen()), hdr.Length)
   159  	}
   160  	packetLen := int(hdr.ParsedLen() + hdr.Length)
   161  	return hdr, data[:packetLen], data[packetLen:], nil
   162  }
   163  
   164  // ParseHeader parses the header.
   165  // For short header packets: up to the packet number.
   166  // For long header packets:
   167  // * if we understand the version: up to the packet number
   168  // * if not, only the invariant part of the header
   169  func parseHeader(b *bytes.Reader) (*Header, error) {
   170  	startLen := b.Len()
   171  	typeByte, err := b.ReadByte()
   172  	if err != nil {
   173  		return nil, err
   174  	}
   175  
   176  	h := &Header{typeByte: typeByte}
   177  	err = h.parseLongHeader(b)
   178  	h.parsedLen = protocol.ByteCount(startLen - b.Len())
   179  	return h, err
   180  }
   181  
   182  func (h *Header) parseLongHeader(b *bytes.Reader) error {
   183  	v, err := utils.BigEndian.ReadUint32(b)
   184  	if err != nil {
   185  		return err
   186  	}
   187  	h.Version = protocol.Version(v)
   188  	if h.Version != 0 && h.typeByte&0x40 == 0 {
   189  		return errors.New("not a QUIC packet")
   190  	}
   191  	destConnIDLen, err := b.ReadByte()
   192  	if err != nil {
   193  		return err
   194  	}
   195  	h.DestConnectionID, err = protocol.ReadConnectionID(b, int(destConnIDLen))
   196  	if err != nil {
   197  		return err
   198  	}
   199  	srcConnIDLen, err := b.ReadByte()
   200  	if err != nil {
   201  		return err
   202  	}
   203  	h.SrcConnectionID, err = protocol.ReadConnectionID(b, int(srcConnIDLen))
   204  	if err != nil {
   205  		return err
   206  	}
   207  	if h.Version == 0 { // version negotiation packet
   208  		return nil
   209  	}
   210  	// If we don't understand the version, we have no idea how to interpret the rest of the bytes
   211  	if !protocol.IsSupportedVersion(protocol.SupportedVersions, h.Version) {
   212  		return ErrUnsupportedVersion
   213  	}
   214  
   215  	if h.Version == protocol.Version2 {
   216  		switch h.typeByte >> 4 & 0b11 {
   217  		case 0b00:
   218  			h.Type = protocol.PacketTypeRetry
   219  		case 0b01:
   220  			h.Type = protocol.PacketTypeInitial
   221  		case 0b10:
   222  			h.Type = protocol.PacketType0RTT
   223  		case 0b11:
   224  			h.Type = protocol.PacketTypeHandshake
   225  		}
   226  	} else {
   227  		switch h.typeByte >> 4 & 0b11 {
   228  		case 0b00:
   229  			h.Type = protocol.PacketTypeInitial
   230  		case 0b01:
   231  			h.Type = protocol.PacketType0RTT
   232  		case 0b10:
   233  			h.Type = protocol.PacketTypeHandshake
   234  		case 0b11:
   235  			h.Type = protocol.PacketTypeRetry
   236  		}
   237  	}
   238  
   239  	if h.Type == protocol.PacketTypeRetry {
   240  		tokenLen := b.Len() - 16
   241  		if tokenLen <= 0 {
   242  			return io.EOF
   243  		}
   244  		h.Token = make([]byte, tokenLen)
   245  		if _, err := io.ReadFull(b, h.Token); err != nil {
   246  			return err
   247  		}
   248  		_, err := b.Seek(16, io.SeekCurrent)
   249  		return err
   250  	}
   251  
   252  	if h.Type == protocol.PacketTypeInitial {
   253  		tokenLen, err := quicvarint.Read(b)
   254  		if err != nil {
   255  			return err
   256  		}
   257  		if tokenLen > uint64(b.Len()) {
   258  			return io.EOF
   259  		}
   260  		h.Token = make([]byte, tokenLen)
   261  		if _, err := io.ReadFull(b, h.Token); err != nil {
   262  			return err
   263  		}
   264  	}
   265  
   266  	pl, err := quicvarint.Read(b)
   267  	if err != nil {
   268  		return err
   269  	}
   270  	h.Length = protocol.ByteCount(pl)
   271  	return nil
   272  }
   273  
   274  // ParsedLen returns the number of bytes that were consumed when parsing the header
   275  func (h *Header) ParsedLen() protocol.ByteCount {
   276  	return h.parsedLen
   277  }
   278  
   279  // ParseExtended parses the version dependent part of the header.
   280  // The Reader has to be set such that it points to the first byte of the header.
   281  func (h *Header) ParseExtended(b *bytes.Reader, ver protocol.Version) (*ExtendedHeader, error) {
   282  	extHdr := h.toExtendedHeader()
   283  	reservedBitsValid, err := extHdr.parse(b, ver)
   284  	if err != nil {
   285  		return nil, err
   286  	}
   287  	if !reservedBitsValid {
   288  		return extHdr, ErrInvalidReservedBits
   289  	}
   290  	return extHdr, nil
   291  }
   292  
   293  func (h *Header) toExtendedHeader() *ExtendedHeader {
   294  	return &ExtendedHeader{Header: *h}
   295  }
   296  
   297  // PacketType is the type of the packet, for logging purposes
   298  func (h *Header) PacketType() string {
   299  	return h.Type.String()
   300  }