github.com/danielpfeifer02/quic-go-prio-packs@v0.41.0-28/internal/wire/extended_header.go (about)

     1  package wire
     2  
     3  import (
     4  	"bytes"
     5  	"encoding/binary"
     6  	"errors"
     7  	"fmt"
     8  	"io"
     9  
    10  	"github.com/danielpfeifer02/quic-go-prio-packs/internal/protocol"
    11  	"github.com/danielpfeifer02/quic-go-prio-packs/internal/utils"
    12  	"github.com/danielpfeifer02/quic-go-prio-packs/quicvarint"
    13  )
    14  
    15  // ErrInvalidReservedBits is returned when the reserved bits are incorrect.
    16  // When this error is returned, parsing continues, and an ExtendedHeader is returned.
    17  // This is necessary because we need to decrypt the packet in that case,
    18  // in order to avoid a timing side-channel.
    19  var ErrInvalidReservedBits = errors.New("invalid reserved bits")
    20  
    21  // ExtendedHeader is the header of a QUIC packet.
    22  type ExtendedHeader struct {
    23  	Header
    24  
    25  	typeByte byte
    26  
    27  	KeyPhase protocol.KeyPhaseBit
    28  
    29  	PacketNumberLen protocol.PacketNumberLen
    30  	PacketNumber    protocol.PacketNumber
    31  
    32  	parsedLen protocol.ByteCount
    33  }
    34  
    35  func (h *ExtendedHeader) parse(b *bytes.Reader, v protocol.Version) (bool /* reserved bits valid */, error) {
    36  	startLen := b.Len()
    37  	// read the (now unencrypted) first byte
    38  	var err error
    39  	h.typeByte, err = b.ReadByte()
    40  	if err != nil {
    41  		return false, err
    42  	}
    43  	if _, err := b.Seek(int64(h.Header.ParsedLen())-1, io.SeekCurrent); err != nil {
    44  		return false, err
    45  	}
    46  	reservedBitsValid, err := h.parseLongHeader(b, v)
    47  	if err != nil {
    48  		return false, err
    49  	}
    50  	h.parsedLen = protocol.ByteCount(startLen - b.Len())
    51  	return reservedBitsValid, err
    52  }
    53  
    54  func (h *ExtendedHeader) parseLongHeader(b *bytes.Reader, _ protocol.Version) (bool /* reserved bits valid */, error) {
    55  	if err := h.readPacketNumber(b); err != nil {
    56  		return false, err
    57  	}
    58  	if h.typeByte&0xc != 0 {
    59  		return false, nil
    60  	}
    61  	return true, nil
    62  }
    63  
    64  func (h *ExtendedHeader) readPacketNumber(b *bytes.Reader) error {
    65  	h.PacketNumberLen = protocol.PacketNumberLen(h.typeByte&0x3) + 1
    66  	switch h.PacketNumberLen {
    67  	case protocol.PacketNumberLen1:
    68  		n, err := b.ReadByte()
    69  		if err != nil {
    70  			return err
    71  		}
    72  		h.PacketNumber = protocol.PacketNumber(n)
    73  	case protocol.PacketNumberLen2:
    74  		n, err := utils.BigEndian.ReadUint16(b)
    75  		if err != nil {
    76  			return err
    77  		}
    78  		h.PacketNumber = protocol.PacketNumber(n)
    79  	case protocol.PacketNumberLen3:
    80  		n, err := utils.BigEndian.ReadUint24(b)
    81  		if err != nil {
    82  			return err
    83  		}
    84  		h.PacketNumber = protocol.PacketNumber(n)
    85  	case protocol.PacketNumberLen4:
    86  		n, err := utils.BigEndian.ReadUint32(b)
    87  		if err != nil {
    88  			return err
    89  		}
    90  		h.PacketNumber = protocol.PacketNumber(n)
    91  	default:
    92  		return fmt.Errorf("invalid packet number length: %d", h.PacketNumberLen)
    93  	}
    94  	return nil
    95  }
    96  
    97  // Append appends the Header.
    98  func (h *ExtendedHeader) Append(b []byte, v protocol.Version) ([]byte, error) {
    99  	if h.DestConnectionID.Len() > protocol.MaxConnIDLen {
   100  		return nil, fmt.Errorf("invalid connection ID length: %d bytes", h.DestConnectionID.Len())
   101  	}
   102  	if h.SrcConnectionID.Len() > protocol.MaxConnIDLen {
   103  		return nil, fmt.Errorf("invalid connection ID length: %d bytes", h.SrcConnectionID.Len())
   104  	}
   105  
   106  	var packetType uint8
   107  	if v == protocol.Version2 {
   108  		//nolint:exhaustive
   109  		switch h.Type {
   110  		case protocol.PacketTypeInitial:
   111  			packetType = 0b01
   112  		case protocol.PacketType0RTT:
   113  			packetType = 0b10
   114  		case protocol.PacketTypeHandshake:
   115  			packetType = 0b11
   116  		case protocol.PacketTypeRetry:
   117  			packetType = 0b00
   118  		}
   119  	} else {
   120  		//nolint:exhaustive
   121  		switch h.Type {
   122  		case protocol.PacketTypeInitial:
   123  			packetType = 0b00
   124  		case protocol.PacketType0RTT:
   125  			packetType = 0b01
   126  		case protocol.PacketTypeHandshake:
   127  			packetType = 0b10
   128  		case protocol.PacketTypeRetry:
   129  			packetType = 0b11
   130  		}
   131  	}
   132  	firstByte := 0xc0 | packetType<<4
   133  	if h.Type != protocol.PacketTypeRetry {
   134  		// Retry packets don't have a packet number
   135  		firstByte |= uint8(h.PacketNumberLen - 1)
   136  	}
   137  
   138  	b = append(b, firstByte)
   139  	b = append(b, make([]byte, 4)...)
   140  	binary.BigEndian.PutUint32(b[len(b)-4:], uint32(h.Version))
   141  	b = append(b, uint8(h.DestConnectionID.Len()))
   142  	b = append(b, h.DestConnectionID.Bytes()...)
   143  	b = append(b, uint8(h.SrcConnectionID.Len()))
   144  	b = append(b, h.SrcConnectionID.Bytes()...)
   145  
   146  	//nolint:exhaustive
   147  	switch h.Type {
   148  	case protocol.PacketTypeRetry:
   149  		b = append(b, h.Token...)
   150  		return b, nil
   151  	case protocol.PacketTypeInitial:
   152  		b = quicvarint.Append(b, uint64(len(h.Token)))
   153  		b = append(b, h.Token...)
   154  	}
   155  	b = quicvarint.AppendWithLen(b, uint64(h.Length), 2)
   156  	return appendPacketNumber(b, h.PacketNumber, h.PacketNumberLen)
   157  }
   158  
   159  // ParsedLen returns the number of bytes that were consumed when parsing the header
   160  func (h *ExtendedHeader) ParsedLen() protocol.ByteCount {
   161  	return h.parsedLen
   162  }
   163  
   164  // GetLength determines the length of the Header.
   165  func (h *ExtendedHeader) GetLength(_ protocol.Version) protocol.ByteCount {
   166  	length := 1 /* type byte */ + 4 /* version */ + 1 /* dest conn ID len */ + protocol.ByteCount(h.DestConnectionID.Len()) + 1 /* src conn ID len */ + protocol.ByteCount(h.SrcConnectionID.Len()) + protocol.ByteCount(h.PacketNumberLen) + 2 /* length */
   167  	if h.Type == protocol.PacketTypeInitial {
   168  		length += quicvarint.Len(uint64(len(h.Token))) + protocol.ByteCount(len(h.Token))
   169  	}
   170  	return length
   171  }
   172  
   173  // Log logs the Header
   174  func (h *ExtendedHeader) Log(logger utils.Logger) {
   175  	var token string
   176  	if h.Type == protocol.PacketTypeInitial || h.Type == protocol.PacketTypeRetry {
   177  		if len(h.Token) == 0 {
   178  			token = "Token: (empty), "
   179  		} else {
   180  			token = fmt.Sprintf("Token: %#x, ", h.Token)
   181  		}
   182  		if h.Type == protocol.PacketTypeRetry {
   183  			logger.Debugf("\tLong Header{Type: %s, DestConnectionID: %s, SrcConnectionID: %s, %sVersion: %s}", h.Type, h.DestConnectionID, h.SrcConnectionID, token, h.Version)
   184  			return
   185  		}
   186  	}
   187  	logger.Debugf("\tLong Header{Type: %s, DestConnectionID: %s, SrcConnectionID: %s, %sPacketNumber: %d, PacketNumberLen: %d, Length: %d, Version: %s}", h.Type, h.DestConnectionID, h.SrcConnectionID, token, h.PacketNumber, h.PacketNumberLen, h.Length, h.Version)
   188  }
   189  
   190  func appendPacketNumber(b []byte, pn protocol.PacketNumber, pnLen protocol.PacketNumberLen) ([]byte, error) {
   191  	switch pnLen {
   192  	case protocol.PacketNumberLen1:
   193  		b = append(b, uint8(pn))
   194  	case protocol.PacketNumberLen2:
   195  		buf := make([]byte, 2)
   196  		binary.BigEndian.PutUint16(buf, uint16(pn))
   197  		b = append(b, buf...)
   198  	case protocol.PacketNumberLen3:
   199  		buf := make([]byte, 4)
   200  		binary.BigEndian.PutUint32(buf, uint32(pn))
   201  		b = append(b, buf[1:]...)
   202  	case protocol.PacketNumberLen4:
   203  		buf := make([]byte, 4)
   204  		binary.BigEndian.PutUint32(buf, uint32(pn))
   205  		b = append(b, buf...)
   206  	default:
   207  		return nil, fmt.Errorf("invalid packet number length: %d", pnLen)
   208  	}
   209  	return b, nil
   210  }