github.com/ooni/psiphon/tunnel-core@v0.0.0-20230105123940-fe12a24c96ee/oovendor/quic-go/internal/wire/extended_header.go (about)

     1  package wire
     2  
     3  import (
     4  	"bytes"
     5  	"errors"
     6  	"fmt"
     7  	"io"
     8  
     9  	"github.com/ooni/psiphon/tunnel-core/oovendor/quic-go/internal/protocol"
    10  	"github.com/ooni/psiphon/tunnel-core/oovendor/quic-go/internal/utils"
    11  	"github.com/ooni/psiphon/tunnel-core/oovendor/quic-go/quicvarint"
    12  )
    13  
    14  // ErrInvalidReservedBits is returned when the reserved bits are incorrect.
    15  // When this error is returned, parsing continues, and an ExtendedHeader is returned.
    16  // This is necessary because we need to decrypt the packet in that case,
    17  // in order to avoid a timing side-channel.
    18  var ErrInvalidReservedBits = errors.New("invalid reserved bits")
    19  
// ExtendedHeader is the header of a QUIC packet.
// It embeds Header and adds the fields that parse decodes from the (now
// unencrypted) first byte and the bytes that follow it: the key phase
// (short headers only) and the packet number together with its encoded length.
type ExtendedHeader struct {
	Header

	// typeByte is the first byte of the packet as re-read by parse; the packet
	// number length (low two bits), key phase and reserved bits are decoded
	// from it.
	typeByte byte

	// KeyPhase is taken from bit 0x04 of the first byte of a short-header
	// packet when parsing, and written to that bit by writeShortHeader.
	KeyPhase protocol.KeyPhaseBit

	// PacketNumberLen is the number of bytes (1 to 4) used on the wire for
	// PacketNumber.
	PacketNumberLen protocol.PacketNumberLen
	// PacketNumber is the value read from, or written to, the packet number
	// field of the header.
	PacketNumber    protocol.PacketNumber

	// parsedLen is the total number of bytes consumed by a successful parse,
	// including the bytes covered by the embedded Header.
	parsedLen protocol.ByteCount
}
    33  
    34  func (h *ExtendedHeader) parse(b *bytes.Reader, v protocol.VersionNumber) (bool /* reserved bits valid */, error) {
    35  	startLen := b.Len()
    36  	// read the (now unencrypted) first byte
    37  	var err error
    38  	h.typeByte, err = b.ReadByte()
    39  	if err != nil {
    40  		return false, err
    41  	}
    42  	if _, err := b.Seek(int64(h.Header.ParsedLen())-1, io.SeekCurrent); err != nil {
    43  		return false, err
    44  	}
    45  	var reservedBitsValid bool
    46  	if h.IsLongHeader {
    47  		reservedBitsValid, err = h.parseLongHeader(b, v)
    48  	} else {
    49  		reservedBitsValid, err = h.parseShortHeader(b, v)
    50  	}
    51  	if err != nil {
    52  		return false, err
    53  	}
    54  	h.parsedLen = protocol.ByteCount(startLen - b.Len())
    55  	return reservedBitsValid, err
    56  }
    57  
    58  func (h *ExtendedHeader) parseLongHeader(b *bytes.Reader, _ protocol.VersionNumber) (bool /* reserved bits valid */, error) {
    59  	if err := h.readPacketNumber(b); err != nil {
    60  		return false, err
    61  	}
    62  	if h.typeByte&0xc != 0 {
    63  		return false, nil
    64  	}
    65  	return true, nil
    66  }
    67  
    68  func (h *ExtendedHeader) parseShortHeader(b *bytes.Reader, _ protocol.VersionNumber) (bool /* reserved bits valid */, error) {
    69  	h.KeyPhase = protocol.KeyPhaseZero
    70  	if h.typeByte&0x4 > 0 {
    71  		h.KeyPhase = protocol.KeyPhaseOne
    72  	}
    73  
    74  	if err := h.readPacketNumber(b); err != nil {
    75  		return false, err
    76  	}
    77  	if h.typeByte&0x18 != 0 {
    78  		return false, nil
    79  	}
    80  	return true, nil
    81  }
    82  
    83  func (h *ExtendedHeader) readPacketNumber(b *bytes.Reader) error {
    84  	h.PacketNumberLen = protocol.PacketNumberLen(h.typeByte&0x3) + 1
    85  	switch h.PacketNumberLen {
    86  	case protocol.PacketNumberLen1:
    87  		n, err := b.ReadByte()
    88  		if err != nil {
    89  			return err
    90  		}
    91  		h.PacketNumber = protocol.PacketNumber(n)
    92  	case protocol.PacketNumberLen2:
    93  		n, err := utils.BigEndian.ReadUint16(b)
    94  		if err != nil {
    95  			return err
    96  		}
    97  		h.PacketNumber = protocol.PacketNumber(n)
    98  	case protocol.PacketNumberLen3:
    99  		n, err := utils.BigEndian.ReadUint24(b)
   100  		if err != nil {
   101  			return err
   102  		}
   103  		h.PacketNumber = protocol.PacketNumber(n)
   104  	case protocol.PacketNumberLen4:
   105  		n, err := utils.BigEndian.ReadUint32(b)
   106  		if err != nil {
   107  			return err
   108  		}
   109  		h.PacketNumber = protocol.PacketNumber(n)
   110  	default:
   111  		return fmt.Errorf("invalid packet number length: %d", h.PacketNumberLen)
   112  	}
   113  	return nil
   114  }
   115  
   116  // Write writes the Header.
   117  func (h *ExtendedHeader) Write(b *bytes.Buffer, ver protocol.VersionNumber) error {
   118  	if h.DestConnectionID.Len() > protocol.MaxConnIDLen {
   119  		return fmt.Errorf("invalid connection ID length: %d bytes", h.DestConnectionID.Len())
   120  	}
   121  	if h.SrcConnectionID.Len() > protocol.MaxConnIDLen {
   122  		return fmt.Errorf("invalid connection ID length: %d bytes", h.SrcConnectionID.Len())
   123  	}
   124  	if h.IsLongHeader {
   125  		return h.writeLongHeader(b, ver)
   126  	}
   127  	return h.writeShortHeader(b, ver)
   128  }
   129  
   130  func (h *ExtendedHeader) writeLongHeader(b *bytes.Buffer, _ protocol.VersionNumber) error {
   131  	var packetType uint8
   132  	//nolint:exhaustive
   133  	switch h.Type {
   134  	case protocol.PacketTypeInitial:
   135  		packetType = 0x0
   136  	case protocol.PacketType0RTT:
   137  		packetType = 0x1
   138  	case protocol.PacketTypeHandshake:
   139  		packetType = 0x2
   140  	case protocol.PacketTypeRetry:
   141  		packetType = 0x3
   142  	}
   143  	firstByte := 0xc0 | packetType<<4
   144  	if h.Type != protocol.PacketTypeRetry {
   145  		// Retry packets don't have a packet number
   146  		firstByte |= uint8(h.PacketNumberLen - 1)
   147  	}
   148  
   149  	b.WriteByte(firstByte)
   150  	utils.BigEndian.WriteUint32(b, uint32(h.Version))
   151  	b.WriteByte(uint8(h.DestConnectionID.Len()))
   152  	b.Write(h.DestConnectionID.Bytes())
   153  	b.WriteByte(uint8(h.SrcConnectionID.Len()))
   154  	b.Write(h.SrcConnectionID.Bytes())
   155  
   156  	//nolint:exhaustive
   157  	switch h.Type {
   158  	case protocol.PacketTypeRetry:
   159  		b.Write(h.Token)
   160  		return nil
   161  	case protocol.PacketTypeInitial:
   162  		quicvarint.Write(b, uint64(len(h.Token)))
   163  		b.Write(h.Token)
   164  	}
   165  	quicvarint.WriteWithLen(b, uint64(h.Length), 2)
   166  	return h.writePacketNumber(b)
   167  }
   168  
   169  func (h *ExtendedHeader) writeShortHeader(b *bytes.Buffer, _ protocol.VersionNumber) error {
   170  	typeByte := 0x40 | uint8(h.PacketNumberLen-1)
   171  	if h.KeyPhase == protocol.KeyPhaseOne {
   172  		typeByte |= byte(1 << 2)
   173  	}
   174  
   175  	b.WriteByte(typeByte)
   176  	b.Write(h.DestConnectionID.Bytes())
   177  	return h.writePacketNumber(b)
   178  }
   179  
   180  func (h *ExtendedHeader) writePacketNumber(b *bytes.Buffer) error {
   181  	switch h.PacketNumberLen {
   182  	case protocol.PacketNumberLen1:
   183  		b.WriteByte(uint8(h.PacketNumber))
   184  	case protocol.PacketNumberLen2:
   185  		utils.BigEndian.WriteUint16(b, uint16(h.PacketNumber))
   186  	case protocol.PacketNumberLen3:
   187  		utils.BigEndian.WriteUint24(b, uint32(h.PacketNumber))
   188  	case protocol.PacketNumberLen4:
   189  		utils.BigEndian.WriteUint32(b, uint32(h.PacketNumber))
   190  	default:
   191  		return fmt.Errorf("invalid packet number length: %d", h.PacketNumberLen)
   192  	}
   193  	return nil
   194  }
   195  
// ParsedLen returns the number of bytes that were consumed when parsing the header,
// including the bytes covered by the embedded Header. The value is only
// recorded by a successful parse; after a failed parse it keeps its previous
// value (zero for a header that was never parsed successfully).
func (h *ExtendedHeader) ParsedLen() protocol.ByteCount {
	return h.parsedLen
}
   200  
   201  // GetLength determines the length of the Header.
   202  func (h *ExtendedHeader) GetLength(v protocol.VersionNumber) protocol.ByteCount {
   203  	if h.IsLongHeader {
   204  		length := 1 /* type byte */ + 4 /* version */ + 1 /* dest conn ID len */ + protocol.ByteCount(h.DestConnectionID.Len()) + 1 /* src conn ID len */ + protocol.ByteCount(h.SrcConnectionID.Len()) + protocol.ByteCount(h.PacketNumberLen) + 2 /* length */
   205  		if h.Type == protocol.PacketTypeInitial {
   206  			length += quicvarint.Len(uint64(len(h.Token))) + protocol.ByteCount(len(h.Token))
   207  		}
   208  		return length
   209  	}
   210  
   211  	length := protocol.ByteCount(1 /* type byte */ + h.DestConnectionID.Len())
   212  	length += protocol.ByteCount(h.PacketNumberLen)
   213  	return length
   214  }
   215  
   216  // Log logs the Header
   217  func (h *ExtendedHeader) Log(logger utils.Logger) {
   218  	if h.IsLongHeader {
   219  		var token string
   220  		if h.Type == protocol.PacketTypeInitial || h.Type == protocol.PacketTypeRetry {
   221  			if len(h.Token) == 0 {
   222  				token = "Token: (empty), "
   223  			} else {
   224  				token = fmt.Sprintf("Token: %#x, ", h.Token)
   225  			}
   226  			if h.Type == protocol.PacketTypeRetry {
   227  				logger.Debugf("\tLong Header{Type: %s, DestConnectionID: %s, SrcConnectionID: %s, %sVersion: %s}", h.Type, h.DestConnectionID, h.SrcConnectionID, token, h.Version)
   228  				return
   229  			}
   230  		}
   231  		logger.Debugf("\tLong Header{Type: %s, DestConnectionID: %s, SrcConnectionID: %s, %sPacketNumber: %d, PacketNumberLen: %d, Length: %d, Version: %s}", h.Type, h.DestConnectionID, h.SrcConnectionID, token, h.PacketNumber, h.PacketNumberLen, h.Length, h.Version)
   232  	} else {
   233  		logger.Debugf("\tShort Header{DestConnectionID: %s, PacketNumber: %d, PacketNumberLen: %d, KeyPhase: %s}", h.DestConnectionID, h.PacketNumber, h.PacketNumberLen, h.KeyPhase)
   234  	}
   235  }