github.com/danielpfeifer02/quic-go-prio-packs@v0.41.0-28/packet_unpacker.go (about)

     1  package quic
     2  
     3  import (
     4  	"bytes"
     5  	"fmt"
     6  	"time"
     7  
     8  	"github.com/danielpfeifer02/quic-go-prio-packs/crypto_turnoff"
     9  	"github.com/danielpfeifer02/quic-go-prio-packs/internal/handshake"
    10  	"github.com/danielpfeifer02/quic-go-prio-packs/internal/protocol"
    11  	"github.com/danielpfeifer02/quic-go-prio-packs/internal/qerr"
    12  	"github.com/danielpfeifer02/quic-go-prio-packs/internal/wire"
    13  )
    14  
    15  type headerDecryptor interface {
    16  	DecryptHeader(sample []byte, firstByte *byte, pnBytes []byte)
    17  }
    18  
    19  type headerParseError struct {
    20  	err error
    21  }
    22  
    23  func (e *headerParseError) Unwrap() error {
    24  	return e.err
    25  }
    26  
    27  func (e *headerParseError) Error() string {
    28  	return e.err.Error()
    29  }
    30  
    31  type unpackedPacket struct {
    32  	hdr             *wire.ExtendedHeader
    33  	encryptionLevel protocol.EncryptionLevel
    34  	data            []byte
    35  }
    36  
    37  // The packetUnpacker unpacks QUIC packets.
    38  type packetUnpacker struct {
    39  	cs handshake.CryptoSetup
    40  
    41  	shortHdrConnIDLen int
    42  }
    43  
    44  var _ unpacker = &packetUnpacker{}
    45  
    46  func newPacketUnpacker(cs handshake.CryptoSetup, shortHdrConnIDLen int) *packetUnpacker {
    47  	return &packetUnpacker{
    48  		cs:                cs,
    49  		shortHdrConnIDLen: shortHdrConnIDLen,
    50  	}
    51  }
    52  
    53  // UnpackLongHeader unpacks a Long Header packet.
    54  // If the reserved bits are invalid, the error is wire.ErrInvalidReservedBits.
    55  // If any other error occurred when parsing the header, the error is of type headerParseError.
    56  // If decrypting the payload fails for any reason, the error is the error returned by the AEAD.
    57  func (u *packetUnpacker) UnpackLongHeader(hdr *wire.Header, rcvTime time.Time, data []byte, v protocol.Version) (*unpackedPacket, error) {
    58  	var encLevel protocol.EncryptionLevel
    59  	var extHdr *wire.ExtendedHeader
    60  	var decrypted []byte
    61  	//nolint:exhaustive // Retry packets can't be unpacked.
    62  	switch hdr.Type {
    63  	case protocol.PacketTypeInitial:
    64  		encLevel = protocol.EncryptionInitial
    65  		opener, err := u.cs.GetInitialOpener()
    66  		if err != nil {
    67  			return nil, err
    68  		}
    69  		extHdr, decrypted, err = u.unpackLongHeaderPacket(opener, hdr, data, v)
    70  		if err != nil {
    71  			return nil, err
    72  		}
    73  	case protocol.PacketTypeHandshake:
    74  		encLevel = protocol.EncryptionHandshake
    75  		opener, err := u.cs.GetHandshakeOpener()
    76  		if err != nil {
    77  			return nil, err
    78  		}
    79  		extHdr, decrypted, err = u.unpackLongHeaderPacket(opener, hdr, data, v)
    80  		if err != nil {
    81  			return nil, err
    82  		}
    83  	case protocol.PacketType0RTT:
    84  		encLevel = protocol.Encryption0RTT
    85  		opener, err := u.cs.Get0RTTOpener()
    86  		if err != nil {
    87  			return nil, err
    88  		}
    89  		extHdr, decrypted, err = u.unpackLongHeaderPacket(opener, hdr, data, v)
    90  		if err != nil {
    91  			return nil, err
    92  		}
    93  	default:
    94  		return nil, fmt.Errorf("unknown packet type: %s", hdr.Type)
    95  	}
    96  
    97  	if len(decrypted) == 0 {
    98  		return nil, &qerr.TransportError{
    99  			ErrorCode:    qerr.ProtocolViolation,
   100  			ErrorMessage: "empty packet",
   101  		}
   102  	}
   103  
   104  	return &unpackedPacket{
   105  		hdr:             extHdr,
   106  		encryptionLevel: encLevel,
   107  		data:            decrypted,
   108  	}, nil
   109  }
   110  
   111  func (u *packetUnpacker) UnpackShortHeader(rcvTime time.Time, data []byte) (protocol.PacketNumber, protocol.PacketNumberLen, protocol.KeyPhaseBit, []byte, error) {
   112  	opener, err := u.cs.Get1RTTOpener()
   113  	if err != nil {
   114  		return 0, 0, 0, nil, err
   115  	}
   116  	pn, pnLen, kp, decrypted, err := u.unpackShortHeaderPacket(opener, rcvTime, data)
   117  	if err != nil {
   118  		return 0, 0, 0, nil, err
   119  	}
   120  	if len(decrypted) == 0 {
   121  		return 0, 0, 0, nil, &qerr.TransportError{
   122  			ErrorCode:    qerr.ProtocolViolation,
   123  			ErrorMessage: "empty packet",
   124  		}
   125  	}
   126  	return pn, pnLen, kp, decrypted, nil
   127  }
   128  
   129  func (u *packetUnpacker) unpackLongHeaderPacket(opener handshake.LongHeaderOpener, hdr *wire.Header, data []byte, v protocol.Version) (*wire.ExtendedHeader, []byte, error) {
   130  	extHdr, parseErr := u.unpackLongHeader(opener, hdr, data, v)
   131  	// If the reserved bits are set incorrectly, we still need to continue unpacking.
   132  	// This avoids a timing side-channel, which otherwise might allow an attacker
   133  	// to gain information about the header encryption.
   134  	if parseErr != nil && parseErr != wire.ErrInvalidReservedBits {
   135  		return nil, nil, parseErr
   136  	}
   137  	extHdrLen := extHdr.ParsedLen()
   138  	extHdr.PacketNumber = opener.DecodePacketNumber(extHdr.PacketNumber, extHdr.PacketNumberLen)
   139  	decrypted, err := opener.Open(data[extHdrLen:extHdrLen], data[extHdrLen:], extHdr.PacketNumber, data[:extHdrLen])
   140  	if err != nil {
   141  		return nil, nil, err
   142  	}
   143  	if parseErr != nil {
   144  		return nil, nil, parseErr
   145  	}
   146  	return extHdr, decrypted, nil
   147  }
   148  
   149  func (u *packetUnpacker) unpackShortHeaderPacket(opener handshake.ShortHeaderOpener, rcvTime time.Time, data []byte) (protocol.PacketNumber, protocol.PacketNumberLen, protocol.KeyPhaseBit, []byte, error) {
   150  	l, pn, pnLen, kp, parseErr := u.unpackShortHeader(opener, data)
   151  	// If the reserved bits are set incorrectly, we still need to continue unpacking.
   152  	// This avoids a timing side-channel, which otherwise might allow an attacker
   153  	// to gain information about the header encryption.
   154  	if parseErr != nil && parseErr != wire.ErrInvalidReservedBits {
   155  		return 0, 0, 0, nil, &headerParseError{parseErr}
   156  	}
   157  	pn = opener.DecodePacketNumber(pn, pnLen)
   158  	decrypted, err := opener.Open(data[l:l], data[l:], rcvTime, pn, kp, data[:l])
   159  	if err != nil {
   160  		return 0, 0, 0, nil, err
   161  	}
   162  	return pn, pnLen, kp, decrypted, parseErr
   163  }
   164  
   165  func (u *packetUnpacker) unpackShortHeader(hd headerDecryptor, data []byte) (int, protocol.PacketNumber, protocol.PacketNumberLen, protocol.KeyPhaseBit, error) {
   166  	hdrLen := 1 /* first header byte */ + u.shortHdrConnIDLen
   167  
   168  	// NO_CRYPTO_TAG
   169  	// here the offset for checking if a packet is too small
   170  	// is different since no crypto overhead is in the data
   171  	offset := 0
   172  	if !crypto_turnoff.CRYPTO_TURNED_OFF {
   173  		offset = 16
   174  	}
   175  	if len(data) < hdrLen+4+offset {
   176  		return 0, 0, 0, 0, fmt.Errorf("packet too small, expected at least 20 bytes after the header, got %d", len(data)-hdrLen)
   177  	}
   178  
   179  	origPNBytes := make([]byte, 4)
   180  	copy(origPNBytes, data[hdrLen:hdrLen+4])
   181  	// 2. decrypt the header, assuming a 4 byte packet number
   182  	hd.DecryptHeader(
   183  		// here the (in case of no crypto) wrong offsed does not matter since DecryptHeader will not have an effect
   184  		data[hdrLen+4:hdrLen+4+offset],
   185  		&data[0],
   186  		data[hdrLen:hdrLen+4],
   187  	)
   188  	// 3. parse the header (and learn the actual length of the packet number)
   189  	l, pn, pnLen, kp, parseErr := wire.ParseShortHeader(data, u.shortHdrConnIDLen)
   190  	if parseErr != nil && parseErr != wire.ErrInvalidReservedBits {
   191  		return l, pn, pnLen, kp, parseErr
   192  	}
   193  	// 4. if the packet number is shorter than 4 bytes, replace the remaining bytes with the copy we saved earlier
   194  	if pnLen != protocol.PacketNumberLen4 {
   195  		copy(data[hdrLen+int(pnLen):hdrLen+4], origPNBytes[int(pnLen):])
   196  	}
   197  	return l, pn, pnLen, kp, parseErr
   198  }
   199  
   200  // The error is either nil, a wire.ErrInvalidReservedBits or of type headerParseError.
   201  func (u *packetUnpacker) unpackLongHeader(hd headerDecryptor, hdr *wire.Header, data []byte, v protocol.Version) (*wire.ExtendedHeader, error) {
   202  	extHdr, err := unpackLongHeader(hd, hdr, data, v)
   203  	if err != nil && err != wire.ErrInvalidReservedBits {
   204  		return nil, &headerParseError{err: err}
   205  	}
   206  	return extHdr, err
   207  }
   208  
   209  func unpackLongHeader(hd headerDecryptor, hdr *wire.Header, data []byte, v protocol.Version) (*wire.ExtendedHeader, error) {
   210  	r := bytes.NewReader(data)
   211  
   212  	// NO_CRYPTO_TAG
   213  	// here the offset for checking if a packet is too small
   214  	// is different since no crypto overhead is in the data
   215  	offset := 0
   216  	if !crypto_turnoff.CRYPTO_TURNED_OFF {
   217  		offset = 16
   218  	}
   219  
   220  	hdrLen := hdr.ParsedLen()
   221  	if protocol.ByteCount(len(data)) < hdrLen+4+protocol.ByteCount(offset) {
   222  		//nolint:stylecheck
   223  		return nil, fmt.Errorf("Packet too small. Expected at least 20 bytes after the header, got %d", protocol.ByteCount(len(data))-hdrLen)
   224  	}
   225  	// The packet number can be up to 4 bytes long, but we won't know the length until we decrypt it.
   226  	// 1. save a copy of the 4 bytes
   227  	origPNBytes := make([]byte, 4)
   228  	copy(origPNBytes, data[hdrLen:hdrLen+4])
   229  	// 2. decrypt the header, assuming a 4 byte packet number
   230  	hd.DecryptHeader(
   231  		data[hdrLen+4:hdrLen+4+protocol.ByteCount(offset)],
   232  		&data[0],
   233  		data[hdrLen:hdrLen+4],
   234  	)
   235  	// 3. parse the header (and learn the actual length of the packet number)
   236  	extHdr, parseErr := hdr.ParseExtended(r, v)
   237  	if parseErr != nil && parseErr != wire.ErrInvalidReservedBits {
   238  		return nil, parseErr
   239  	}
   240  	// 4. if the packet number is shorter than 4 bytes, replace the remaining bytes with the copy we saved earlier
   241  	if extHdr.PacketNumberLen != protocol.PacketNumberLen4 {
   242  		copy(data[extHdr.ParsedLen():hdrLen+4], origPNBytes[int(extHdr.PacketNumberLen):])
   243  	}
   244  	return extHdr, parseErr
   245  }