github.com/tumi8/quic-go@v0.37.4-tum/noninternal/wire/header.go (about)

     1  package wire
     2  
     3  import (
     4  	"bytes"
     5  	"encoding/binary"
     6  	"errors"
     7  	"fmt"
     8  	"io"
     9  
    10  	"github.com/tumi8/quic-go/noninternal/protocol"
    11  	"github.com/tumi8/quic-go/noninternal/utils"
    12  	"github.com/tumi8/quic-go/quicvarint"
    13  )
    14  
    15  // ParseConnectionID parses the destination connection ID of a packet.
    16  func ParseConnectionID(data []byte, shortHeaderConnIDLen int) (protocol.ConnectionID, error) {
    17  	if len(data) == 0 {
    18  		return protocol.ConnectionID{}, io.EOF
    19  	}
    20  	if !IsLongHeaderPacket(data[0]) {
    21  		if len(data) < shortHeaderConnIDLen+1 {
    22  			return protocol.ConnectionID{}, io.EOF
    23  		}
    24  		return protocol.ParseConnectionID(data[1 : 1+shortHeaderConnIDLen]), nil
    25  	}
    26  	if len(data) < 6 {
    27  		return protocol.ConnectionID{}, io.EOF
    28  	}
    29  	destConnIDLen := int(data[5])
    30  	if destConnIDLen > protocol.MaxConnIDLen {
    31  		return protocol.ConnectionID{}, protocol.ErrInvalidConnectionIDLen
    32  	}
    33  	if len(data) < 6+destConnIDLen {
    34  		return protocol.ConnectionID{}, io.EOF
    35  	}
    36  	return protocol.ParseConnectionID(data[6 : 6+destConnIDLen]), nil
    37  }
    38  
    39  // ParseArbitraryLenConnectionIDs parses the most general form of a Long Header packet,
    40  // using only the version-independent packet format as described in Section 5.1 of RFC 8999:
    41  // https://datatracker.ietf.org/doc/html/rfc8999#section-5.1.
    42  // This function should only be called on Long Header packets for which we don't support the version.
    43  func ParseArbitraryLenConnectionIDs(data []byte) (bytesParsed int, dest, src protocol.ArbitraryLenConnectionID, _ error) {
    44  	r := bytes.NewReader(data)
    45  	remaining := r.Len()
    46  	src, dest, err := parseArbitraryLenConnectionIDs(r)
    47  	return remaining - r.Len(), src, dest, err
    48  }
    49  
    50  func parseArbitraryLenConnectionIDs(r *bytes.Reader) (dest, src protocol.ArbitraryLenConnectionID, _ error) {
    51  	r.Seek(5, io.SeekStart) // skip first byte and version field
    52  	destConnIDLen, err := r.ReadByte()
    53  	if err != nil {
    54  		return nil, nil, err
    55  	}
    56  	destConnID := make(protocol.ArbitraryLenConnectionID, destConnIDLen)
    57  	if _, err := io.ReadFull(r, destConnID); err != nil {
    58  		if err == io.ErrUnexpectedEOF {
    59  			err = io.EOF
    60  		}
    61  		return nil, nil, err
    62  	}
    63  	srcConnIDLen, err := r.ReadByte()
    64  	if err != nil {
    65  		return nil, nil, err
    66  	}
    67  	srcConnID := make(protocol.ArbitraryLenConnectionID, srcConnIDLen)
    68  	if _, err := io.ReadFull(r, srcConnID); err != nil {
    69  		if err == io.ErrUnexpectedEOF {
    70  			err = io.EOF
    71  		}
    72  		return nil, nil, err
    73  	}
    74  	return destConnID, srcConnID, nil
    75  }
    76  
    77  // IsLongHeaderPacket says if this is a Long Header packet
    78  func IsLongHeaderPacket(firstByte byte) bool {
    79  	return firstByte&0x80 > 0
    80  }
    81  
    82  // ParseVersion parses the QUIC version.
    83  // It should only be called for Long Header packets (Short Header packets don't contain a version number).
    84  func ParseVersion(data []byte) (protocol.VersionNumber, error) {
    85  	if len(data) < 5 {
    86  		return 0, io.EOF
    87  	}
    88  	return protocol.VersionNumber(binary.BigEndian.Uint32(data[1:5])), nil
    89  }
    90  
    91  // IsVersionNegotiationPacket says if this is a version negotiation packet
    92  func IsVersionNegotiationPacket(b []byte) bool {
    93  	if len(b) < 5 {
    94  		return false
    95  	}
    96  	return IsLongHeaderPacket(b[0]) && b[1] == 0 && b[2] == 0 && b[3] == 0 && b[4] == 0
    97  }
    98  
    99  // Is0RTTPacket says if this is a 0-RTT packet.
   100  // A packet sent with a version we don't understand can never be a 0-RTT packet.
   101  func Is0RTTPacket(b []byte) bool {
   102  	if len(b) < 5 {
   103  		return false
   104  	}
   105  	if !IsLongHeaderPacket(b[0]) {
   106  		return false
   107  	}
   108  	version := protocol.VersionNumber(binary.BigEndian.Uint32(b[1:5]))
   109  	//nolint:exhaustive // We only need to test QUIC versions that we support.
   110  	switch version {
   111  	case protocol.Version1:
   112  		return b[0]>>4&0b11 == 0b01
   113  	case protocol.Version2:
   114  		return b[0]>>4&0b11 == 0b10
   115  	default:
   116  		return false
   117  	}
   118  }
   119  
   120  var ErrUnsupportedVersion = errors.New("unsupported version")
   121  
   122  // The Header is the version independent part of the header
   123  type Header struct {
   124  	typeByte byte
   125  	Type     protocol.PacketType
   126  
   127  	Version          protocol.VersionNumber
   128  	SrcConnectionID  protocol.ConnectionID
   129  	DestConnectionID protocol.ConnectionID
   130  
   131  	Length protocol.ByteCount
   132  
   133  	Token []byte
   134  
   135  	parsedLen protocol.ByteCount // how many bytes were read while parsing this header
   136  }
   137  
   138  // ParsePacket parses a packet.
   139  // If the packet has a long header, the packet is cut according to the length field.
   140  // If we understand the version, the packet is header up unto the packet number.
   141  // Otherwise, only the invariant part of the header is parsed.
   142  func ParsePacket(data []byte) (*Header, []byte, []byte, error) {
   143  	if len(data) == 0 || !IsLongHeaderPacket(data[0]) {
   144  		return nil, nil, nil, errors.New("not a long header packet")
   145  	}
   146  	hdr, err := parseHeader(bytes.NewReader(data))
   147  	if err != nil {
   148  		if err == ErrUnsupportedVersion {
   149  			return hdr, nil, nil, ErrUnsupportedVersion
   150  		}
   151  		return nil, nil, nil, err
   152  	}
   153  	if protocol.ByteCount(len(data)) < hdr.ParsedLen()+hdr.Length {
   154  		return nil, nil, nil, fmt.Errorf("packet length (%d bytes) is smaller than the expected length (%d bytes)", len(data)-int(hdr.ParsedLen()), hdr.Length)
   155  	}
   156  	packetLen := int(hdr.ParsedLen() + hdr.Length)
   157  	return hdr, data[:packetLen], data[packetLen:], nil
   158  }
   159  
   160  // ParseHeader parses the header.
   161  // For short header packets: up to the packet number.
   162  // For long header packets:
   163  // * if we understand the version: up to the packet number
   164  // * if not, only the invariant part of the header
   165  func parseHeader(b *bytes.Reader) (*Header, error) {
   166  	startLen := b.Len()
   167  	typeByte, err := b.ReadByte()
   168  	if err != nil {
   169  		return nil, err
   170  	}
   171  
   172  	h := &Header{typeByte: typeByte}
   173  	err = h.parseLongHeader(b)
   174  	h.parsedLen = protocol.ByteCount(startLen - b.Len())
   175  	return h, err
   176  }
   177  
   178  func (h *Header) parseLongHeader(b *bytes.Reader) error {
   179  	v, err := utils.BigEndian.ReadUint32(b)
   180  	if err != nil {
   181  		return err
   182  	}
   183  	h.Version = protocol.VersionNumber(v)
   184  	if h.Version != 0 && h.typeByte&0x40 == 0 {
   185  		return errors.New("not a QUIC packet")
   186  	}
   187  	destConnIDLen, err := b.ReadByte()
   188  	if err != nil {
   189  		return err
   190  	}
   191  	h.DestConnectionID, err = protocol.ReadConnectionID(b, int(destConnIDLen))
   192  	if err != nil {
   193  		return err
   194  	}
   195  	srcConnIDLen, err := b.ReadByte()
   196  	if err != nil {
   197  		return err
   198  	}
   199  	h.SrcConnectionID, err = protocol.ReadConnectionID(b, int(srcConnIDLen))
   200  	if err != nil {
   201  		return err
   202  	}
   203  	if h.Version == 0 { // version negotiation packet
   204  		return nil
   205  	}
   206  	// If we don't understand the version, we have no idea how to interpret the rest of the bytes
   207  	if !protocol.IsSupportedVersion(protocol.SupportedVersions, h.Version) {
   208  		return ErrUnsupportedVersion
   209  	}
   210  
   211  	if h.Version == protocol.Version2 {
   212  		switch h.typeByte >> 4 & 0b11 {
   213  		case 0b00:
   214  			h.Type = protocol.PacketTypeRetry
   215  		case 0b01:
   216  			h.Type = protocol.PacketTypeInitial
   217  		case 0b10:
   218  			h.Type = protocol.PacketType0RTT
   219  		case 0b11:
   220  			h.Type = protocol.PacketTypeHandshake
   221  		}
   222  	} else {
   223  		switch h.typeByte >> 4 & 0b11 {
   224  		case 0b00:
   225  			h.Type = protocol.PacketTypeInitial
   226  		case 0b01:
   227  			h.Type = protocol.PacketType0RTT
   228  		case 0b10:
   229  			h.Type = protocol.PacketTypeHandshake
   230  		case 0b11:
   231  			h.Type = protocol.PacketTypeRetry
   232  		}
   233  	}
   234  
   235  	if h.Type == protocol.PacketTypeRetry {
   236  		tokenLen := b.Len() - 16
   237  		if tokenLen <= 0 {
   238  			return io.EOF
   239  		}
   240  		h.Token = make([]byte, tokenLen)
   241  		if _, err := io.ReadFull(b, h.Token); err != nil {
   242  			return err
   243  		}
   244  		_, err := b.Seek(16, io.SeekCurrent)
   245  		return err
   246  	}
   247  
   248  	if h.Type == protocol.PacketTypeInitial {
   249  		tokenLen, err := quicvarint.Read(b)
   250  		if err != nil {
   251  			return err
   252  		}
   253  		if tokenLen > uint64(b.Len()) {
   254  			return io.EOF
   255  		}
   256  		h.Token = make([]byte, tokenLen)
   257  		if _, err := io.ReadFull(b, h.Token); err != nil {
   258  			return err
   259  		}
   260  	}
   261  
   262  	pl, err := quicvarint.Read(b)
   263  	if err != nil {
   264  		return err
   265  	}
   266  	h.Length = protocol.ByteCount(pl)
   267  	return nil
   268  }
   269  
   270  // ParsedLen returns the number of bytes that were consumed when parsing the header
   271  func (h *Header) ParsedLen() protocol.ByteCount {
   272  	return h.parsedLen
   273  }
   274  
   275  // ParseExtended parses the version dependent part of the header.
   276  // The Reader has to be set such that it points to the first byte of the header.
   277  func (h *Header) ParseExtended(b *bytes.Reader, ver protocol.VersionNumber) (*ExtendedHeader, error) {
   278  	extHdr := h.toExtendedHeader()
   279  	reservedBitsValid, err := extHdr.parse(b, ver)
   280  	if err != nil {
   281  		return nil, err
   282  	}
   283  	if !reservedBitsValid {
   284  		return extHdr, ErrInvalidReservedBits
   285  	}
   286  	return extHdr, nil
   287  }
   288  
   289  func (h *Header) toExtendedHeader() *ExtendedHeader {
   290  	return &ExtendedHeader{Header: *h}
   291  }
   292  
   293  // PacketType is the type of the packet, for logging purposes
   294  func (h *Header) PacketType() string {
   295  	return h.Type.String()
   296  }