github.com/ooni/psiphon/tunnel-core@v0.0.0-20230105123940-fe12a24c96ee/oovendor/quic-go/internal/wire/header.go (about)

     1  package wire
     2  
     3  import (
     4  	"bytes"
     5  	"encoding/binary"
     6  	"errors"
     7  	"fmt"
     8  	"io"
     9  
    10  	"github.com/ooni/psiphon/tunnel-core/oovendor/quic-go/internal/protocol"
    11  	"github.com/ooni/psiphon/tunnel-core/oovendor/quic-go/internal/utils"
    12  	"github.com/ooni/psiphon/tunnel-core/oovendor/quic-go/quicvarint"
    13  )
    14  
    15  // ParseConnectionID parses the destination connection ID of a packet.
    16  // It uses the data slice for the connection ID.
    17  // That means that the connection ID must not be used after the packet buffer is released.
    18  func ParseConnectionID(data []byte, shortHeaderConnIDLen int) (protocol.ConnectionID, error) {
    19  	if len(data) == 0 {
    20  		return nil, io.EOF
    21  	}
    22  	isLongHeader := data[0]&0x80 > 0
    23  	if !isLongHeader {
    24  		if len(data) < shortHeaderConnIDLen+1 {
    25  			return nil, io.EOF
    26  		}
    27  		return protocol.ConnectionID(data[1 : 1+shortHeaderConnIDLen]), nil
    28  	}
    29  	if len(data) < 6 {
    30  		return nil, io.EOF
    31  	}
    32  	destConnIDLen := int(data[5])
    33  	if len(data) < 6+destConnIDLen {
    34  		return nil, io.EOF
    35  	}
    36  	return protocol.ConnectionID(data[6 : 6+destConnIDLen]), nil
    37  }
    38  
    39  // IsVersionNegotiationPacket says if this is a version negotiation packet
    40  func IsVersionNegotiationPacket(b []byte) bool {
    41  	if len(b) < 5 {
    42  		return false
    43  	}
    44  	return b[0]&0x80 > 0 && b[1] == 0 && b[2] == 0 && b[3] == 0 && b[4] == 0
    45  }
    46  
    47  // Is0RTTPacket says if this is a 0-RTT packet.
    48  // A packet sent with a version we don't understand can never be a 0-RTT packet.
    49  func Is0RTTPacket(b []byte) bool {
    50  	if len(b) < 5 {
    51  		return false
    52  	}
    53  	if b[0]&0x80 == 0 {
    54  		return false
    55  	}
    56  	if !protocol.IsSupportedVersion(protocol.SupportedVersions, protocol.VersionNumber(binary.BigEndian.Uint32(b[1:5]))) {
    57  		return false
    58  	}
    59  	return b[0]&0x30>>4 == 0x1
    60  }
    61  
    62  var ErrUnsupportedVersion = errors.New("unsupported version")
    63  
    64  // The Header is the version independent part of the header
    65  type Header struct {
    66  	IsLongHeader bool
    67  	typeByte     byte
    68  	Type         protocol.PacketType
    69  
    70  	Version          protocol.VersionNumber
    71  	SrcConnectionID  protocol.ConnectionID
    72  	DestConnectionID protocol.ConnectionID
    73  
    74  	Length protocol.ByteCount
    75  
    76  	Token []byte
    77  
    78  	parsedLen protocol.ByteCount // how many bytes were read while parsing this header
    79  }
    80  
    81  // ParsePacket parses a packet.
    82  // If the packet has a long header, the packet is cut according to the length field.
    83  // If we understand the version, the packet is header up unto the packet number.
    84  // Otherwise, only the invariant part of the header is parsed.
    85  func ParsePacket(data []byte, shortHeaderConnIDLen int) (*Header, []byte /* packet data */, []byte /* rest */, error) {
    86  	hdr, err := parseHeader(bytes.NewReader(data), shortHeaderConnIDLen)
    87  	if err != nil {
    88  		if err == ErrUnsupportedVersion {
    89  			return hdr, nil, nil, ErrUnsupportedVersion
    90  		}
    91  		return nil, nil, nil, err
    92  	}
    93  	var rest []byte
    94  	if hdr.IsLongHeader {
    95  		if protocol.ByteCount(len(data)) < hdr.ParsedLen()+hdr.Length {
    96  			return nil, nil, nil, fmt.Errorf("packet length (%d bytes) is smaller than the expected length (%d bytes)", len(data)-int(hdr.ParsedLen()), hdr.Length)
    97  		}
    98  		packetLen := int(hdr.ParsedLen() + hdr.Length)
    99  		rest = data[packetLen:]
   100  		data = data[:packetLen]
   101  	}
   102  	return hdr, data, rest, nil
   103  }
   104  
   105  // ParseHeader parses the header.
   106  // For short header packets: up to the packet number.
   107  // For long header packets:
   108  // * if we understand the version: up to the packet number
   109  // * if not, only the invariant part of the header
   110  func parseHeader(b *bytes.Reader, shortHeaderConnIDLen int) (*Header, error) {
   111  	startLen := b.Len()
   112  	h, err := parseHeaderImpl(b, shortHeaderConnIDLen)
   113  	if err != nil {
   114  		return h, err
   115  	}
   116  	h.parsedLen = protocol.ByteCount(startLen - b.Len())
   117  	return h, err
   118  }
   119  
   120  func parseHeaderImpl(b *bytes.Reader, shortHeaderConnIDLen int) (*Header, error) {
   121  	typeByte, err := b.ReadByte()
   122  	if err != nil {
   123  		return nil, err
   124  	}
   125  
   126  	h := &Header{
   127  		typeByte:     typeByte,
   128  		IsLongHeader: typeByte&0x80 > 0,
   129  	}
   130  
   131  	if !h.IsLongHeader {
   132  		if h.typeByte&0x40 == 0 {
   133  			return nil, errors.New("not a QUIC packet")
   134  		}
   135  		if err := h.parseShortHeader(b, shortHeaderConnIDLen); err != nil {
   136  			return nil, err
   137  		}
   138  		return h, nil
   139  	}
   140  	return h, h.parseLongHeader(b)
   141  }
   142  
   143  func (h *Header) parseShortHeader(b *bytes.Reader, shortHeaderConnIDLen int) error {
   144  	var err error
   145  	h.DestConnectionID, err = protocol.ReadConnectionID(b, shortHeaderConnIDLen)
   146  	return err
   147  }
   148  
   149  func (h *Header) parseLongHeader(b *bytes.Reader) error {
   150  	v, err := utils.BigEndian.ReadUint32(b)
   151  	if err != nil {
   152  		return err
   153  	}
   154  	h.Version = protocol.VersionNumber(v)
   155  	if h.Version != 0 && h.typeByte&0x40 == 0 {
   156  		return errors.New("not a QUIC packet")
   157  	}
   158  	destConnIDLen, err := b.ReadByte()
   159  	if err != nil {
   160  		return err
   161  	}
   162  	h.DestConnectionID, err = protocol.ReadConnectionID(b, int(destConnIDLen))
   163  	if err != nil {
   164  		return err
   165  	}
   166  	srcConnIDLen, err := b.ReadByte()
   167  	if err != nil {
   168  		return err
   169  	}
   170  	h.SrcConnectionID, err = protocol.ReadConnectionID(b, int(srcConnIDLen))
   171  	if err != nil {
   172  		return err
   173  	}
   174  	if h.Version == 0 { // version negotiation packet
   175  		return nil
   176  	}
   177  	// If we don't understand the version, we have no idea how to interpret the rest of the bytes
   178  	if !protocol.IsSupportedVersion(protocol.SupportedVersions, h.Version) {
   179  		return ErrUnsupportedVersion
   180  	}
   181  
   182  	switch (h.typeByte & 0x30) >> 4 {
   183  	case 0x0:
   184  		h.Type = protocol.PacketTypeInitial
   185  	case 0x1:
   186  		h.Type = protocol.PacketType0RTT
   187  	case 0x2:
   188  		h.Type = protocol.PacketTypeHandshake
   189  	case 0x3:
   190  		h.Type = protocol.PacketTypeRetry
   191  	}
   192  
   193  	if h.Type == protocol.PacketTypeRetry {
   194  		tokenLen := b.Len() - 16
   195  		if tokenLen <= 0 {
   196  			return io.EOF
   197  		}
   198  		h.Token = make([]byte, tokenLen)
   199  		if _, err := io.ReadFull(b, h.Token); err != nil {
   200  			return err
   201  		}
   202  		_, err := b.Seek(16, io.SeekCurrent)
   203  		return err
   204  	}
   205  
   206  	if h.Type == protocol.PacketTypeInitial {
   207  		tokenLen, err := quicvarint.Read(b)
   208  		if err != nil {
   209  			return err
   210  		}
   211  		if tokenLen > uint64(b.Len()) {
   212  			return io.EOF
   213  		}
   214  		h.Token = make([]byte, tokenLen)
   215  		if _, err := io.ReadFull(b, h.Token); err != nil {
   216  			return err
   217  		}
   218  	}
   219  
   220  	pl, err := quicvarint.Read(b)
   221  	if err != nil {
   222  		return err
   223  	}
   224  	h.Length = protocol.ByteCount(pl)
   225  	return nil
   226  }
   227  
   228  // ParsedLen returns the number of bytes that were consumed when parsing the header
   229  func (h *Header) ParsedLen() protocol.ByteCount {
   230  	return h.parsedLen
   231  }
   232  
   233  // ParseExtended parses the version dependent part of the header.
   234  // The Reader has to be set such that it points to the first byte of the header.
   235  func (h *Header) ParseExtended(b *bytes.Reader, ver protocol.VersionNumber) (*ExtendedHeader, error) {
   236  	extHdr := h.toExtendedHeader()
   237  	reservedBitsValid, err := extHdr.parse(b, ver)
   238  	if err != nil {
   239  		return nil, err
   240  	}
   241  	if !reservedBitsValid {
   242  		return extHdr, ErrInvalidReservedBits
   243  	}
   244  	return extHdr, nil
   245  }
   246  
   247  func (h *Header) toExtendedHeader() *ExtendedHeader {
   248  	return &ExtendedHeader{Header: *h}
   249  }
   250  
   251  // PacketType is the type of the packet, for logging purposes
   252  func (h *Header) PacketType() string {
   253  	if h.IsLongHeader {
   254  		return h.Type.String()
   255  	}
   256  	return "1-RTT"
   257  }