github.com/sagernet/quic-go@v0.43.1-beta.1/ech/packet_unpacker.go (about)

     1  package quic
     2  
     3  import (
     4  	"bytes"
     5  	"fmt"
     6  	"time"
     7  
     8  	"github.com/sagernet/quic-go/internal/handshake_ech"
     9  	"github.com/sagernet/quic-go/internal/protocol"
    10  	"github.com/sagernet/quic-go/internal/qerr"
    11  	"github.com/sagernet/quic-go/internal/wire"
    12  )
    13  
    // headerDecryptor removes header protection from a packet.
//
// Callers in this file pass a 16-byte sample taken from the protected packet,
// a pointer to the packet's first byte, and the (assumed 4-byte) packet number
// field. All of these point into the packet buffer, and the buffer is parsed
// right afterwards, so implementations are expected to unprotect firstByte and
// pnBytes in place.
type headerDecryptor interface {
	// DecryptHeader unmasks *firstByte and pnBytes using sample.
	DecryptHeader(sample []byte, firstByte *byte, pnBytes []byte)
}
    17  
    // headerParseError reports that a packet header could not be parsed. This is
// distinct from a failure to decrypt the payload, which is returned as the raw
// AEAD error. The underlying cause stays reachable through Unwrap, so
// errors.Is and errors.As see through this wrapper.
type headerParseError struct {
	err error // the underlying parse failure
}

// Error returns the wrapped error's message unchanged.
func (p *headerParseError) Error() string { return p.err.Error() }

// Unwrap returns the wrapped error.
func (p *headerParseError) Unwrap() error { return p.err }
    29  
    30  type unpackedPacket struct {
    31  	hdr             *wire.ExtendedHeader
    32  	encryptionLevel protocol.EncryptionLevel
    33  	data            []byte
    34  }
    35  
    36  // The packetUnpacker unpacks QUIC packets.
    37  type packetUnpacker struct {
    38  	cs handshake.CryptoSetup
    39  
    40  	shortHdrConnIDLen int
    41  }
    42  
    43  var _ unpacker = &packetUnpacker{}
    44  
    45  func newPacketUnpacker(cs handshake.CryptoSetup, shortHdrConnIDLen int) *packetUnpacker {
    46  	return &packetUnpacker{
    47  		cs:                cs,
    48  		shortHdrConnIDLen: shortHdrConnIDLen,
    49  	}
    50  }
    51  
    52  // UnpackLongHeader unpacks a Long Header packet.
    53  // If the reserved bits are invalid, the error is wire.ErrInvalidReservedBits.
    54  // If any other error occurred when parsing the header, the error is of type headerParseError.
    55  // If decrypting the payload fails for any reason, the error is the error returned by the AEAD.
    56  func (u *packetUnpacker) UnpackLongHeader(hdr *wire.Header, rcvTime time.Time, data []byte, v protocol.Version) (*unpackedPacket, error) {
    57  	var encLevel protocol.EncryptionLevel
    58  	var extHdr *wire.ExtendedHeader
    59  	var decrypted []byte
    60  	//nolint:exhaustive // Retry packets can't be unpacked.
    61  	switch hdr.Type {
    62  	case protocol.PacketTypeInitial:
    63  		encLevel = protocol.EncryptionInitial
    64  		opener, err := u.cs.GetInitialOpener()
    65  		if err != nil {
    66  			return nil, err
    67  		}
    68  		extHdr, decrypted, err = u.unpackLongHeaderPacket(opener, hdr, data, v)
    69  		if err != nil {
    70  			return nil, err
    71  		}
    72  	case protocol.PacketTypeHandshake:
    73  		encLevel = protocol.EncryptionHandshake
    74  		opener, err := u.cs.GetHandshakeOpener()
    75  		if err != nil {
    76  			return nil, err
    77  		}
    78  		extHdr, decrypted, err = u.unpackLongHeaderPacket(opener, hdr, data, v)
    79  		if err != nil {
    80  			return nil, err
    81  		}
    82  	case protocol.PacketType0RTT:
    83  		encLevel = protocol.Encryption0RTT
    84  		opener, err := u.cs.Get0RTTOpener()
    85  		if err != nil {
    86  			return nil, err
    87  		}
    88  		extHdr, decrypted, err = u.unpackLongHeaderPacket(opener, hdr, data, v)
    89  		if err != nil {
    90  			return nil, err
    91  		}
    92  	default:
    93  		return nil, fmt.Errorf("unknown packet type: %s", hdr.Type)
    94  	}
    95  
    96  	if len(decrypted) == 0 {
    97  		return nil, &qerr.TransportError{
    98  			ErrorCode:    qerr.ProtocolViolation,
    99  			ErrorMessage: "empty packet",
   100  		}
   101  	}
   102  
   103  	return &unpackedPacket{
   104  		hdr:             extHdr,
   105  		encryptionLevel: encLevel,
   106  		data:            decrypted,
   107  	}, nil
   108  }
   109  
   110  func (u *packetUnpacker) UnpackShortHeader(rcvTime time.Time, data []byte) (protocol.PacketNumber, protocol.PacketNumberLen, protocol.KeyPhaseBit, []byte, error) {
   111  	opener, err := u.cs.Get1RTTOpener()
   112  	if err != nil {
   113  		return 0, 0, 0, nil, err
   114  	}
   115  	pn, pnLen, kp, decrypted, err := u.unpackShortHeaderPacket(opener, rcvTime, data)
   116  	if err != nil {
   117  		return 0, 0, 0, nil, err
   118  	}
   119  	if len(decrypted) == 0 {
   120  		return 0, 0, 0, nil, &qerr.TransportError{
   121  			ErrorCode:    qerr.ProtocolViolation,
   122  			ErrorMessage: "empty packet",
   123  		}
   124  	}
   125  	return pn, pnLen, kp, decrypted, nil
   126  }
   127  
   128  func (u *packetUnpacker) unpackLongHeaderPacket(opener handshake.LongHeaderOpener, hdr *wire.Header, data []byte, v protocol.Version) (*wire.ExtendedHeader, []byte, error) {
   129  	extHdr, parseErr := u.unpackLongHeader(opener, hdr, data, v)
   130  	// If the reserved bits are set incorrectly, we still need to continue unpacking.
   131  	// This avoids a timing side-channel, which otherwise might allow an attacker
   132  	// to gain information about the header encryption.
   133  	if parseErr != nil && parseErr != wire.ErrInvalidReservedBits {
   134  		return nil, nil, parseErr
   135  	}
   136  	extHdrLen := extHdr.ParsedLen()
   137  	extHdr.PacketNumber = opener.DecodePacketNumber(extHdr.PacketNumber, extHdr.PacketNumberLen)
   138  	decrypted, err := opener.Open(data[extHdrLen:extHdrLen], data[extHdrLen:], extHdr.PacketNumber, data[:extHdrLen])
   139  	if err != nil {
   140  		return nil, nil, err
   141  	}
   142  	if parseErr != nil {
   143  		return nil, nil, parseErr
   144  	}
   145  	return extHdr, decrypted, nil
   146  }
   147  
   148  func (u *packetUnpacker) unpackShortHeaderPacket(opener handshake.ShortHeaderOpener, rcvTime time.Time, data []byte) (protocol.PacketNumber, protocol.PacketNumberLen, protocol.KeyPhaseBit, []byte, error) {
   149  	l, pn, pnLen, kp, parseErr := u.unpackShortHeader(opener, data)
   150  	// If the reserved bits are set incorrectly, we still need to continue unpacking.
   151  	// This avoids a timing side-channel, which otherwise might allow an attacker
   152  	// to gain information about the header encryption.
   153  	if parseErr != nil && parseErr != wire.ErrInvalidReservedBits {
   154  		return 0, 0, 0, nil, &headerParseError{parseErr}
   155  	}
   156  	pn = opener.DecodePacketNumber(pn, pnLen)
   157  	decrypted, err := opener.Open(data[l:l], data[l:], rcvTime, pn, kp, data[:l])
   158  	if err != nil {
   159  		return 0, 0, 0, nil, err
   160  	}
   161  	return pn, pnLen, kp, decrypted, parseErr
   162  }
   163  
   164  func (u *packetUnpacker) unpackShortHeader(hd headerDecryptor, data []byte) (int, protocol.PacketNumber, protocol.PacketNumberLen, protocol.KeyPhaseBit, error) {
   165  	hdrLen := 1 /* first header byte */ + u.shortHdrConnIDLen
   166  	if len(data) < hdrLen+4+16 {
   167  		return 0, 0, 0, 0, fmt.Errorf("packet too small, expected at least 20 bytes after the header, got %d", len(data)-hdrLen)
   168  	}
   169  	origPNBytes := make([]byte, 4)
   170  	copy(origPNBytes, data[hdrLen:hdrLen+4])
   171  	// 2. decrypt the header, assuming a 4 byte packet number
   172  	hd.DecryptHeader(
   173  		data[hdrLen+4:hdrLen+4+16],
   174  		&data[0],
   175  		data[hdrLen:hdrLen+4],
   176  	)
   177  	// 3. parse the header (and learn the actual length of the packet number)
   178  	l, pn, pnLen, kp, parseErr := wire.ParseShortHeader(data, u.shortHdrConnIDLen)
   179  	if parseErr != nil && parseErr != wire.ErrInvalidReservedBits {
   180  		return l, pn, pnLen, kp, parseErr
   181  	}
   182  	// 4. if the packet number is shorter than 4 bytes, replace the remaining bytes with the copy we saved earlier
   183  	if pnLen != protocol.PacketNumberLen4 {
   184  		copy(data[hdrLen+int(pnLen):hdrLen+4], origPNBytes[int(pnLen):])
   185  	}
   186  	return l, pn, pnLen, kp, parseErr
   187  }
   188  
   189  // The error is either nil, a wire.ErrInvalidReservedBits or of type headerParseError.
   190  func (u *packetUnpacker) unpackLongHeader(hd headerDecryptor, hdr *wire.Header, data []byte, v protocol.Version) (*wire.ExtendedHeader, error) {
   191  	extHdr, err := unpackLongHeader(hd, hdr, data, v)
   192  	if err != nil && err != wire.ErrInvalidReservedBits {
   193  		return nil, &headerParseError{err: err}
   194  	}
   195  	return extHdr, err
   196  }
   197  
   198  func unpackLongHeader(hd headerDecryptor, hdr *wire.Header, data []byte, v protocol.Version) (*wire.ExtendedHeader, error) {
   199  	r := bytes.NewReader(data)
   200  
   201  	hdrLen := hdr.ParsedLen()
   202  	if protocol.ByteCount(len(data)) < hdrLen+4+16 {
   203  		//nolint:stylecheck
   204  		return nil, fmt.Errorf("Packet too small. Expected at least 20 bytes after the header, got %d", protocol.ByteCount(len(data))-hdrLen)
   205  	}
   206  	// The packet number can be up to 4 bytes long, but we won't know the length until we decrypt it.
   207  	// 1. save a copy of the 4 bytes
   208  	origPNBytes := make([]byte, 4)
   209  	copy(origPNBytes, data[hdrLen:hdrLen+4])
   210  	// 2. decrypt the header, assuming a 4 byte packet number
   211  	hd.DecryptHeader(
   212  		data[hdrLen+4:hdrLen+4+16],
   213  		&data[0],
   214  		data[hdrLen:hdrLen+4],
   215  	)
   216  	// 3. parse the header (and learn the actual length of the packet number)
   217  	extHdr, parseErr := hdr.ParseExtended(r, v)
   218  	if parseErr != nil && parseErr != wire.ErrInvalidReservedBits {
   219  		return nil, parseErr
   220  	}
   221  	// 4. if the packet number is shorter than 4 bytes, replace the remaining bytes with the copy we saved earlier
   222  	if extHdr.PacketNumberLen != protocol.PacketNumberLen4 {
   223  		copy(data[extHdr.ParsedLen():hdrLen+4], origPNBytes[int(extHdr.PacketNumberLen):])
   224  	}
   225  	return extHdr, parseErr
   226  }