github.com/danielpfeifer02/quic-go-prio-packs@v0.41.0-28/internal/wire/header.go (about)

     1  package wire
     2  
     3  import (
     4  	"bytes"
     5  	"encoding/binary"
     6  	"errors"
     7  	"fmt"
     8  	"io"
     9  
    10  	"github.com/danielpfeifer02/quic-go-prio-packs/crypto_turnoff"
    11  	"github.com/danielpfeifer02/quic-go-prio-packs/internal/protocol"
    12  	"github.com/danielpfeifer02/quic-go-prio-packs/internal/utils"
    13  	"github.com/danielpfeifer02/quic-go-prio-packs/quicvarint"
    14  )
    15  
    16  // ParseConnectionID parses the destination connection ID of a packet.
    17  func ParseConnectionID(data []byte, shortHeaderConnIDLen int) (protocol.ConnectionID, error) {
    18  	if len(data) == 0 {
    19  		return protocol.ConnectionID{}, io.EOF
    20  	}
    21  	if !IsLongHeaderPacket(data[0]) {
    22  		if len(data) < shortHeaderConnIDLen+1 {
    23  			return protocol.ConnectionID{}, io.EOF
    24  		}
    25  		return protocol.ParseConnectionID(data[1 : 1+shortHeaderConnIDLen]), nil
    26  	}
    27  	if len(data) < 6 {
    28  		return protocol.ConnectionID{}, io.EOF
    29  	}
    30  	destConnIDLen := int(data[5])
    31  	if destConnIDLen > protocol.MaxConnIDLen {
    32  		return protocol.ConnectionID{}, protocol.ErrInvalidConnectionIDLen
    33  	}
    34  	if len(data) < 6+destConnIDLen {
    35  		return protocol.ConnectionID{}, io.EOF
    36  	}
    37  	return protocol.ParseConnectionID(data[6 : 6+destConnIDLen]), nil
    38  }
    39  
    40  // ParseArbitraryLenConnectionIDs parses the most general form of a Long Header packet,
    41  // using only the version-independent packet format as described in Section 5.1 of RFC 8999:
    42  // https://datatracker.ietf.org/doc/html/rfc8999#section-5.1.
    43  // This function should only be called on Long Header packets for which we don't support the version.
    44  func ParseArbitraryLenConnectionIDs(data []byte) (bytesParsed int, dest, src protocol.ArbitraryLenConnectionID, _ error) {
    45  	r := bytes.NewReader(data)
    46  	remaining := r.Len()
    47  	src, dest, err := parseArbitraryLenConnectionIDs(r)
    48  	return remaining - r.Len(), src, dest, err
    49  }
    50  
    51  func parseArbitraryLenConnectionIDs(r *bytes.Reader) (dest, src protocol.ArbitraryLenConnectionID, _ error) {
    52  	r.Seek(5, io.SeekStart) // skip first byte and version field
    53  	destConnIDLen, err := r.ReadByte()
    54  	if err != nil {
    55  		return nil, nil, err
    56  	}
    57  	destConnID := make(protocol.ArbitraryLenConnectionID, destConnIDLen)
    58  	if _, err := io.ReadFull(r, destConnID); err != nil {
    59  		if err == io.ErrUnexpectedEOF {
    60  			err = io.EOF
    61  		}
    62  		return nil, nil, err
    63  	}
    64  	srcConnIDLen, err := r.ReadByte()
    65  	if err != nil {
    66  		return nil, nil, err
    67  	}
    68  	srcConnID := make(protocol.ArbitraryLenConnectionID, srcConnIDLen)
    69  	if _, err := io.ReadFull(r, srcConnID); err != nil {
    70  		if err == io.ErrUnexpectedEOF {
    71  			err = io.EOF
    72  		}
    73  		return nil, nil, err
    74  	}
    75  	return destConnID, srcConnID, nil
    76  }
    77  
    78  func IsPotentialQUICPacket(firstByte byte) bool {
    79  	return firstByte&0x40 > 0
    80  }
    81  
    82  // IsLongHeaderPacket says if this is a Long Header packet
    83  func IsLongHeaderPacket(firstByte byte) bool {
    84  	return firstByte&0x80 > 0
    85  }
    86  
    87  // ParseVersion parses the QUIC version.
    88  // It should only be called for Long Header packets (Short Header packets don't contain a version number).
    89  func ParseVersion(data []byte) (protocol.Version, error) {
    90  	if len(data) < 5 {
    91  		return 0, io.EOF
    92  	}
    93  	return protocol.Version(binary.BigEndian.Uint32(data[1:5])), nil
    94  }
    95  
    96  // IsVersionNegotiationPacket says if this is a version negotiation packet
    97  func IsVersionNegotiationPacket(b []byte) bool {
    98  	if len(b) < 5 {
    99  		return false
   100  	}
   101  	return IsLongHeaderPacket(b[0]) && b[1] == 0 && b[2] == 0 && b[3] == 0 && b[4] == 0
   102  }
   103  
   104  // Is0RTTPacket says if this is a 0-RTT packet.
   105  // A packet sent with a version we don't understand can never be a 0-RTT packet.
   106  func Is0RTTPacket(b []byte) bool {
   107  	if len(b) < 5 {
   108  		return false
   109  	}
   110  	if !IsLongHeaderPacket(b[0]) {
   111  		return false
   112  	}
   113  	version := protocol.Version(binary.BigEndian.Uint32(b[1:5]))
   114  	//nolint:exhaustive // We only need to test QUIC versions that we support.
   115  	switch version {
   116  	case protocol.Version1:
   117  		return b[0]>>4&0b11 == 0b01
   118  	case protocol.Version2:
   119  		return b[0]>>4&0b11 == 0b10
   120  	default:
   121  		return false
   122  	}
   123  }
   124  
// ErrUnsupportedVersion is returned when a Long Header packet carries a non-zero
// version that is not contained in protocol.SupportedVersions.
// ParsePacket returns it together with the partially parsed Header.
var ErrUnsupportedVersion = errors.New("unsupported version")
   126  
// The Header is the version independent part of the header
type Header struct {
	// typeByte is the first byte of the packet, as read by parseHeader.
	typeByte byte
	// Type is derived from two bits of typeByte once the version is known to be
	// supported. It keeps its zero value for version negotiation packets and
	// for unsupported versions.
	Type     protocol.PacketType

	// Version is the 4 byte version field. A value of 0 marks a version negotiation packet.
	Version          protocol.Version
	SrcConnectionID  protocol.ConnectionID
	DestConnectionID protocol.ConnectionID

	// Length is the value of the Length field. It is not parsed for Retry and
	// version negotiation packets. ParsePacket reduces it when
	// crypto_turnoff.CRYPTO_TURNED_OFF is set.
	Length protocol.ByteCount

	// Token is only filled in for Retry and Initial packets.
	Token []byte

	parsedLen protocol.ByteCount // how many bytes were read while parsing this header
}
   142  
   143  // ParsePacket parses a packet.
   144  // If the packet has a long header, the packet is cut according to the length field.
   145  // If we understand the version, the packet is header up unto the packet number.
   146  // Otherwise, only the invariant part of the header is parsed.
   147  func ParsePacket(data []byte) (*Header, []byte, []byte, error) {
   148  	if len(data) == 0 || !IsLongHeaderPacket(data[0]) {
   149  		return nil, nil, nil, errors.New("not a long header packet")
   150  	}
   151  	hdr, err := parseHeader(bytes.NewReader(data))
   152  	if err != nil {
   153  		if err == ErrUnsupportedVersion {
   154  			return hdr, nil, nil, ErrUnsupportedVersion
   155  		}
   156  		return nil, nil, nil, err
   157  	}
   158  
   159  	// NO_CRYPTO_TAG
   160  	if crypto_turnoff.CRYPTO_TURNED_OFF {
   161  		// omit cryptographic operations for prove of concept
   162  		// adapting the header length since no crypto overhead
   163  		// is present
   164  		// see: sealer.Overhead() in packet_packer.go (function starting at line 900)
   165  		// TODO is this always 16?
   166  		hdr.Length -= 16
   167  	}
   168  
   169  	if protocol.ByteCount(len(data)) < hdr.ParsedLen()+hdr.Length {
   170  		return nil, nil, nil, fmt.Errorf("packet length (%d bytes) is smaller than the expected length (%d bytes)", len(data)-int(hdr.ParsedLen()), hdr.Length)
   171  	}
   172  	packetLen := int(hdr.ParsedLen() + hdr.Length)
   173  	return hdr, data[:packetLen], data[packetLen:], nil
   174  }
   175  
   176  // ParseHeader parses the header.
   177  // For short header packets: up to the packet number.
   178  // For long header packets:
   179  // * if we understand the version: up to the packet number
   180  // * if not, only the invariant part of the header
   181  func parseHeader(b *bytes.Reader) (*Header, error) {
   182  	startLen := b.Len()
   183  	typeByte, err := b.ReadByte()
   184  	if err != nil {
   185  		return nil, err
   186  	}
   187  
   188  	h := &Header{typeByte: typeByte}
   189  	err = h.parseLongHeader(b)
   190  	h.parsedLen = protocol.ByteCount(startLen - b.Len())
   191  	return h, err
   192  }
   193  
   194  func (h *Header) parseLongHeader(b *bytes.Reader) error {
   195  	v, err := utils.BigEndian.ReadUint32(b)
   196  	if err != nil {
   197  		return err
   198  	}
   199  	h.Version = protocol.Version(v)
   200  	if h.Version != 0 && h.typeByte&0x40 == 0 {
   201  		return errors.New("not a QUIC packet")
   202  	}
   203  	destConnIDLen, err := b.ReadByte()
   204  	if err != nil {
   205  		return err
   206  	}
   207  	h.DestConnectionID, err = protocol.ReadConnectionID(b, int(destConnIDLen))
   208  	if err != nil {
   209  		return err
   210  	}
   211  	srcConnIDLen, err := b.ReadByte()
   212  	if err != nil {
   213  		return err
   214  	}
   215  	h.SrcConnectionID, err = protocol.ReadConnectionID(b, int(srcConnIDLen))
   216  	if err != nil {
   217  		return err
   218  	}
   219  	if h.Version == 0 { // version negotiation packet
   220  		return nil
   221  	}
   222  	// If we don't understand the version, we have no idea how to interpret the rest of the bytes
   223  	if !protocol.IsSupportedVersion(protocol.SupportedVersions, h.Version) {
   224  		return ErrUnsupportedVersion
   225  	}
   226  
   227  	if h.Version == protocol.Version2 {
   228  		switch h.typeByte >> 4 & 0b11 {
   229  		case 0b00:
   230  			h.Type = protocol.PacketTypeRetry
   231  		case 0b01:
   232  			h.Type = protocol.PacketTypeInitial
   233  		case 0b10:
   234  			h.Type = protocol.PacketType0RTT
   235  		case 0b11:
   236  			h.Type = protocol.PacketTypeHandshake
   237  		}
   238  	} else {
   239  		switch h.typeByte >> 4 & 0b11 {
   240  		case 0b00:
   241  			h.Type = protocol.PacketTypeInitial
   242  		case 0b01:
   243  			h.Type = protocol.PacketType0RTT
   244  		case 0b10:
   245  			h.Type = protocol.PacketTypeHandshake
   246  		case 0b11:
   247  			h.Type = protocol.PacketTypeRetry
   248  		}
   249  	}
   250  
   251  	if h.Type == protocol.PacketTypeRetry {
   252  		tokenLen := b.Len() - 16
   253  		if tokenLen <= 0 {
   254  			return io.EOF
   255  		}
   256  		h.Token = make([]byte, tokenLen)
   257  		if _, err := io.ReadFull(b, h.Token); err != nil {
   258  			return err
   259  		}
   260  		_, err := b.Seek(16, io.SeekCurrent)
   261  		return err
   262  	}
   263  
   264  	if h.Type == protocol.PacketTypeInitial {
   265  		tokenLen, err := quicvarint.Read(b)
   266  		if err != nil {
   267  			return err
   268  		}
   269  		if tokenLen > uint64(b.Len()) {
   270  			return io.EOF
   271  		}
   272  		h.Token = make([]byte, tokenLen)
   273  		if _, err := io.ReadFull(b, h.Token); err != nil {
   274  			return err
   275  		}
   276  	}
   277  
   278  	pl, err := quicvarint.Read(b)
   279  	if err != nil {
   280  		return err
   281  	}
   282  	h.Length = protocol.ByteCount(pl)
   283  	return nil
   284  }
   285  
// ParsedLen returns the number of bytes that were consumed when parsing the header.
// The count is recorded by parseHeader and includes the first byte.
func (h *Header) ParsedLen() protocol.ByteCount {
	return h.parsedLen
}
   290  
   291  // ParseExtended parses the version dependent part of the header.
   292  // The Reader has to be set such that it points to the first byte of the header.
   293  func (h *Header) ParseExtended(b *bytes.Reader, ver protocol.Version) (*ExtendedHeader, error) {
   294  	extHdr := h.toExtendedHeader()
   295  	reservedBitsValid, err := extHdr.parse(b, ver)
   296  	if err != nil {
   297  		return nil, err
   298  	}
   299  	if !reservedBitsValid {
   300  		return extHdr, ErrInvalidReservedBits
   301  	}
   302  	return extHdr, nil
   303  }
   304  
   305  func (h *Header) toExtendedHeader() *ExtendedHeader {
   306  	return &ExtendedHeader{Header: *h}
   307  }
   308  
// PacketType is the type of the packet, for logging purposes.
// It returns the result of calling String on h.Type.
func (h *Header) PacketType() string {
	return h.Type.String()
}