github.com/danielpfeifer02/quic-go-prio-packs@v0.41.0-28/packet_unpacker_test.go (about)

     1  package quic
     2  
     3  import (
     4  	"errors"
     5  	"time"
     6  
     7  	"github.com/danielpfeifer02/quic-go-prio-packs/internal/handshake"
     8  	"github.com/danielpfeifer02/quic-go-prio-packs/internal/mocks"
     9  	"github.com/danielpfeifer02/quic-go-prio-packs/internal/protocol"
    10  	"github.com/danielpfeifer02/quic-go-prio-packs/internal/qerr"
    11  	"github.com/danielpfeifer02/quic-go-prio-packs/internal/wire"
    12  
    13  	. "github.com/onsi/ginkgo/v2"
    14  	. "github.com/onsi/gomega"
    15  	"go.uber.org/mock/gomock"
    16  )
    17  
    18  var _ = Describe("Packet Unpacker", func() {
    19  	var (
    20  		unpacker *packetUnpacker
    21  		cs       *mocks.MockCryptoSetup
    22  		connID   = protocol.ParseConnectionID([]byte{0xde, 0xad, 0xbe, 0xef})
    23  		payload  = []byte("Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.")
    24  	)
    25  
    26  	getLongHeader := func(extHdr *wire.ExtendedHeader) (*wire.Header, []byte) {
    27  		b, err := extHdr.Append(nil, protocol.Version1)
    28  		Expect(err).ToNot(HaveOccurred())
    29  		ExpectWithOffset(1, err).ToNot(HaveOccurred())
    30  		hdrLen := len(b)
    31  		if extHdr.Length > protocol.ByteCount(extHdr.PacketNumberLen) {
    32  			b = append(b, make([]byte, int(extHdr.Length)-int(extHdr.PacketNumberLen))...)
    33  		}
    34  		hdr, _, _, err := wire.ParsePacket(b)
    35  		ExpectWithOffset(1, err).ToNot(HaveOccurred())
    36  		return hdr, b[:hdrLen]
    37  	}
    38  
    39  	getShortHeader := func(connID protocol.ConnectionID, pn protocol.PacketNumber, pnLen protocol.PacketNumberLen, kp protocol.KeyPhaseBit) []byte {
    40  		b, err := wire.AppendShortHeader(nil, connID, pn, pnLen, kp)
    41  		Expect(err).ToNot(HaveOccurred())
    42  		return b
    43  	}
    44  
    45  	BeforeEach(func() {
    46  		cs = mocks.NewMockCryptoSetup(mockCtrl)
    47  		unpacker = newPacketUnpacker(cs, 4)
    48  	})
    49  
    50  	It("errors when the packet is too small to obtain the header decryption sample, for long headers", func() {
    51  		extHdr := &wire.ExtendedHeader{
    52  			Header: wire.Header{
    53  				Type:             protocol.PacketTypeHandshake,
    54  				DestConnectionID: connID,
    55  				Version:          protocol.Version1,
    56  			},
    57  			PacketNumber:    1337,
    58  			PacketNumberLen: protocol.PacketNumberLen2,
    59  		}
    60  		hdr, hdrRaw := getLongHeader(extHdr)
    61  		data := append(hdrRaw, make([]byte, 2 /* fill up packet number */ +15 /* need 16 bytes */)...)
    62  		opener := mocks.NewMockLongHeaderOpener(mockCtrl)
    63  		cs.EXPECT().GetHandshakeOpener().Return(opener, nil)
    64  		_, err := unpacker.UnpackLongHeader(hdr, time.Now(), data, protocol.Version1)
    65  		Expect(err).To(BeAssignableToTypeOf(&headerParseError{}))
    66  		var headerErr *headerParseError
    67  		Expect(errors.As(err, &headerErr)).To(BeTrue())
    68  		Expect(err).To(MatchError("Packet too small. Expected at least 20 bytes after the header, got 19"))
    69  	})
    70  
    71  	It("errors when the packet is too small to obtain the header decryption sample, for short headers", func() {
    72  		b, err := wire.AppendShortHeader(nil, connID, 1337, protocol.PacketNumberLen2, protocol.KeyPhaseOne)
    73  		Expect(err).ToNot(HaveOccurred())
    74  		data := append(b, make([]byte, 2 /* fill up packet number */ +15 /* need 16 bytes */)...)
    75  		opener := mocks.NewMockShortHeaderOpener(mockCtrl)
    76  		cs.EXPECT().Get1RTTOpener().Return(opener, nil)
    77  		_, _, _, _, err = unpacker.UnpackShortHeader(time.Now(), data)
    78  		Expect(err).To(BeAssignableToTypeOf(&headerParseError{}))
    79  		Expect(err).To(MatchError("packet too small, expected at least 20 bytes after the header, got 19"))
    80  	})
    81  
    82  	It("opens Initial packets", func() {
    83  		extHdr := &wire.ExtendedHeader{
    84  			Header: wire.Header{
    85  				Type:             protocol.PacketTypeInitial,
    86  				Length:           3 + 6, // packet number len + payload
    87  				DestConnectionID: connID,
    88  				Version:          protocol.Version1,
    89  			},
    90  			PacketNumber:    2,
    91  			PacketNumberLen: 3,
    92  		}
    93  		hdr, hdrRaw := getLongHeader(extHdr)
    94  		opener := mocks.NewMockLongHeaderOpener(mockCtrl)
    95  		gomock.InOrder(
    96  			cs.EXPECT().GetInitialOpener().Return(opener, nil),
    97  			opener.EXPECT().DecryptHeader(gomock.Any(), gomock.Any(), gomock.Any()),
    98  			opener.EXPECT().DecodePacketNumber(protocol.PacketNumber(2), protocol.PacketNumberLen3).Return(protocol.PacketNumber(1234)),
    99  			opener.EXPECT().Open(gomock.Any(), payload, protocol.PacketNumber(1234), hdrRaw).Return([]byte("decrypted"), nil),
   100  		)
   101  		packet, err := unpacker.UnpackLongHeader(hdr, time.Now(), append(hdrRaw, payload...), protocol.Version1)
   102  		Expect(err).ToNot(HaveOccurred())
   103  		Expect(packet.encryptionLevel).To(Equal(protocol.EncryptionInitial))
   104  		Expect(packet.data).To(Equal([]byte("decrypted")))
   105  	})
   106  
   107  	It("opens 0-RTT packets", func() {
   108  		extHdr := &wire.ExtendedHeader{
   109  			Header: wire.Header{
   110  				Type:             protocol.PacketType0RTT,
   111  				Length:           3 + 6, // packet number len + payload
   112  				DestConnectionID: connID,
   113  				Version:          protocol.Version1,
   114  			},
   115  			PacketNumber:    20,
   116  			PacketNumberLen: 2,
   117  		}
   118  		hdr, hdrRaw := getLongHeader(extHdr)
   119  		opener := mocks.NewMockLongHeaderOpener(mockCtrl)
   120  		gomock.InOrder(
   121  			cs.EXPECT().Get0RTTOpener().Return(opener, nil),
   122  			opener.EXPECT().DecryptHeader(gomock.Any(), gomock.Any(), gomock.Any()),
   123  			opener.EXPECT().DecodePacketNumber(protocol.PacketNumber(20), protocol.PacketNumberLen2).Return(protocol.PacketNumber(321)),
   124  			opener.EXPECT().Open(gomock.Any(), payload, protocol.PacketNumber(321), hdrRaw).Return([]byte("decrypted"), nil),
   125  		)
   126  		packet, err := unpacker.UnpackLongHeader(hdr, time.Now(), append(hdrRaw, payload...), protocol.Version1)
   127  		Expect(err).ToNot(HaveOccurred())
   128  		Expect(packet.encryptionLevel).To(Equal(protocol.Encryption0RTT))
   129  		Expect(packet.data).To(Equal([]byte("decrypted")))
   130  	})
   131  
   132  	It("opens short header packets", func() {
   133  		hdrRaw := getShortHeader(connID, 99, protocol.PacketNumberLen4, protocol.KeyPhaseOne)
   134  		opener := mocks.NewMockShortHeaderOpener(mockCtrl)
   135  		now := time.Now()
   136  		gomock.InOrder(
   137  			cs.EXPECT().Get1RTTOpener().Return(opener, nil),
   138  			opener.EXPECT().DecryptHeader(gomock.Any(), gomock.Any(), gomock.Any()),
   139  			opener.EXPECT().DecodePacketNumber(protocol.PacketNumber(99), protocol.PacketNumberLen4).Return(protocol.PacketNumber(321)),
   140  			opener.EXPECT().Open(gomock.Any(), payload, now, protocol.PacketNumber(321), protocol.KeyPhaseOne, hdrRaw).Return([]byte("decrypted"), nil),
   141  		)
   142  		pn, pnLen, kp, data, err := unpacker.UnpackShortHeader(now, append(hdrRaw, payload...))
   143  		Expect(err).ToNot(HaveOccurred())
   144  		Expect(pn).To(Equal(protocol.PacketNumber(321)))
   145  		Expect(pnLen).To(Equal(protocol.PacketNumberLen4))
   146  		Expect(kp).To(Equal(protocol.KeyPhaseOne))
   147  		Expect(data).To(Equal([]byte("decrypted")))
   148  	})
   149  
   150  	It("returns the error when getting the opener fails", func() {
   151  		hdrRaw := getShortHeader(connID, 0x1337, protocol.PacketNumberLen2, protocol.KeyPhaseOne)
   152  		cs.EXPECT().Get1RTTOpener().Return(nil, handshake.ErrKeysNotYetAvailable)
   153  		_, _, _, _, err := unpacker.UnpackShortHeader(time.Now(), append(hdrRaw, payload...))
   154  		Expect(err).To(MatchError(handshake.ErrKeysNotYetAvailable))
   155  	})
   156  
   157  	It("errors on empty packets, for long header packets", func() {
   158  		extHdr := &wire.ExtendedHeader{
   159  			Header: wire.Header{
   160  				Type:             protocol.PacketTypeHandshake,
   161  				DestConnectionID: connID,
   162  				Version:          Version1,
   163  			},
   164  			KeyPhase:        protocol.KeyPhaseOne,
   165  			PacketNumberLen: protocol.PacketNumberLen4,
   166  		}
   167  		hdr, hdrRaw := getLongHeader(extHdr)
   168  		opener := mocks.NewMockLongHeaderOpener(mockCtrl)
   169  		gomock.InOrder(
   170  			cs.EXPECT().GetHandshakeOpener().Return(opener, nil),
   171  			opener.EXPECT().DecryptHeader(gomock.Any(), gomock.Any(), gomock.Any()),
   172  			opener.EXPECT().DecodePacketNumber(gomock.Any(), gomock.Any()).Return(protocol.PacketNumber(321)),
   173  			opener.EXPECT().Open(gomock.Any(), payload, protocol.PacketNumber(321), hdrRaw).Return([]byte(""), nil),
   174  		)
   175  		_, err := unpacker.UnpackLongHeader(hdr, time.Now(), append(hdrRaw, payload...), protocol.Version1)
   176  		Expect(err).To(MatchError(&qerr.TransportError{
   177  			ErrorCode:    qerr.ProtocolViolation,
   178  			ErrorMessage: "empty packet",
   179  		}))
   180  	})
   181  
   182  	It("errors on empty packets, for short header packets", func() {
   183  		hdrRaw := getShortHeader(connID, 0x42, protocol.PacketNumberLen4, protocol.KeyPhaseOne)
   184  		opener := mocks.NewMockShortHeaderOpener(mockCtrl)
   185  		now := time.Now()
   186  		gomock.InOrder(
   187  			cs.EXPECT().Get1RTTOpener().Return(opener, nil),
   188  			opener.EXPECT().DecryptHeader(gomock.Any(), gomock.Any(), gomock.Any()),
   189  			opener.EXPECT().DecodePacketNumber(gomock.Any(), gomock.Any()).Return(protocol.PacketNumber(321)),
   190  			opener.EXPECT().Open(gomock.Any(), payload, now, protocol.PacketNumber(321), protocol.KeyPhaseOne, hdrRaw).Return([]byte(""), nil),
   191  		)
   192  		_, _, _, _, err := unpacker.UnpackShortHeader(now, append(hdrRaw, payload...))
   193  		Expect(err).To(MatchError(&qerr.TransportError{
   194  			ErrorCode:    qerr.ProtocolViolation,
   195  			ErrorMessage: "empty packet",
   196  		}))
   197  	})
   198  
   199  	It("returns the error when unpacking fails", func() {
   200  		extHdr := &wire.ExtendedHeader{
   201  			Header: wire.Header{
   202  				Type:             protocol.PacketTypeHandshake,
   203  				Length:           3, // packet number len
   204  				DestConnectionID: connID,
   205  				Version:          protocol.Version1,
   206  			},
   207  			PacketNumber:    2,
   208  			PacketNumberLen: 3,
   209  		}
   210  		hdr, hdrRaw := getLongHeader(extHdr)
   211  		opener := mocks.NewMockLongHeaderOpener(mockCtrl)
   212  		cs.EXPECT().GetHandshakeOpener().Return(opener, nil)
   213  		opener.EXPECT().DecryptHeader(gomock.Any(), gomock.Any(), gomock.Any())
   214  		opener.EXPECT().DecodePacketNumber(gomock.Any(), gomock.Any())
   215  		unpackErr := &qerr.TransportError{ErrorCode: qerr.CryptoBufferExceeded}
   216  		opener.EXPECT().Open(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, unpackErr)
   217  		_, err := unpacker.UnpackLongHeader(hdr, time.Now(), append(hdrRaw, payload...), protocol.Version1)
   218  		Expect(err).To(MatchError(unpackErr))
   219  	})
   220  
   221  	It("defends against the timing side-channel when the reserved bits are wrong, for long header packets", func() {
   222  		extHdr := &wire.ExtendedHeader{
   223  			Header: wire.Header{
   224  				Type:             protocol.PacketTypeHandshake,
   225  				DestConnectionID: connID,
   226  				Version:          protocol.Version1,
   227  			},
   228  			PacketNumber:    0x1337,
   229  			PacketNumberLen: 2,
   230  		}
   231  		hdr, hdrRaw := getLongHeader(extHdr)
   232  		hdrRaw[0] |= 0xc
   233  		opener := mocks.NewMockLongHeaderOpener(mockCtrl)
   234  		opener.EXPECT().DecryptHeader(gomock.Any(), gomock.Any(), gomock.Any())
   235  		cs.EXPECT().GetHandshakeOpener().Return(opener, nil)
   236  		opener.EXPECT().DecodePacketNumber(gomock.Any(), gomock.Any())
   237  		opener.EXPECT().Open(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return([]byte("payload"), nil)
   238  		_, err := unpacker.UnpackLongHeader(hdr, time.Now(), append(hdrRaw, payload...), protocol.Version1)
   239  		Expect(err).To(MatchError(wire.ErrInvalidReservedBits))
   240  	})
   241  
   242  	It("defends against the timing side-channel when the reserved bits are wrong, for short header packets", func() {
   243  		hdrRaw := getShortHeader(connID, 0x1337, protocol.PacketNumberLen2, protocol.KeyPhaseZero)
   244  		hdrRaw[0] |= 0x18
   245  		opener := mocks.NewMockShortHeaderOpener(mockCtrl)
   246  		opener.EXPECT().DecryptHeader(gomock.Any(), gomock.Any(), gomock.Any())
   247  		cs.EXPECT().Get1RTTOpener().Return(opener, nil)
   248  		opener.EXPECT().DecodePacketNumber(gomock.Any(), gomock.Any())
   249  		opener.EXPECT().Open(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return([]byte("payload"), nil)
   250  		_, _, _, _, err := unpacker.UnpackShortHeader(time.Now(), append(hdrRaw, payload...))
   251  		Expect(err).To(MatchError(wire.ErrInvalidReservedBits))
   252  	})
   253  
   254  	It("returns the decryption error, when unpacking a packet with wrong reserved bits fails, for long headers", func() {
   255  		extHdr := &wire.ExtendedHeader{
   256  			Header: wire.Header{
   257  				Type:             protocol.PacketTypeHandshake,
   258  				DestConnectionID: connID,
   259  				Version:          protocol.Version1,
   260  			},
   261  			PacketNumber:    0x1337,
   262  			PacketNumberLen: 2,
   263  		}
   264  		hdr, hdrRaw := getLongHeader(extHdr)
   265  		hdrRaw[0] |= 0x18
   266  		opener := mocks.NewMockLongHeaderOpener(mockCtrl)
   267  		opener.EXPECT().DecryptHeader(gomock.Any(), gomock.Any(), gomock.Any())
   268  		cs.EXPECT().GetHandshakeOpener().Return(opener, nil)
   269  		opener.EXPECT().DecodePacketNumber(gomock.Any(), gomock.Any())
   270  		opener.EXPECT().Open(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, handshake.ErrDecryptionFailed)
   271  		_, err := unpacker.UnpackLongHeader(hdr, time.Now(), append(hdrRaw, payload...), protocol.Version1)
   272  		Expect(err).To(MatchError(handshake.ErrDecryptionFailed))
   273  	})
   274  
   275  	It("returns the decryption error, when unpacking a packet with wrong reserved bits fails, for short headers", func() {
   276  		hdrRaw := getShortHeader(connID, 0x1337, protocol.PacketNumberLen2, protocol.KeyPhaseZero)
   277  		hdrRaw[0] |= 0x18
   278  		opener := mocks.NewMockShortHeaderOpener(mockCtrl)
   279  		opener.EXPECT().DecryptHeader(gomock.Any(), gomock.Any(), gomock.Any())
   280  		cs.EXPECT().Get1RTTOpener().Return(opener, nil)
   281  		opener.EXPECT().DecodePacketNumber(gomock.Any(), gomock.Any())
   282  		opener.EXPECT().Open(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, handshake.ErrDecryptionFailed)
   283  		_, _, _, _, err := unpacker.UnpackShortHeader(time.Now(), append(hdrRaw, payload...))
   284  		Expect(err).To(MatchError(handshake.ErrDecryptionFailed))
   285  	})
   286  
   287  	It("decrypts the header", func() {
   288  		extHdr := &wire.ExtendedHeader{
   289  			Header: wire.Header{
   290  				Type:             protocol.PacketTypeHandshake,
   291  				Length:           3, // packet number len
   292  				DestConnectionID: connID,
   293  				Version:          protocol.Version1,
   294  			},
   295  			PacketNumber:    0x1337,
   296  			PacketNumberLen: 2,
   297  		}
   298  		hdr, hdrRaw := getLongHeader(extHdr)
   299  		origHdrRaw := append([]byte{}, hdrRaw...) // save a copy of the header
   300  		firstHdrByte := hdrRaw[0]
   301  		hdrRaw[0] ^= 0xff             // invert the first byte
   302  		hdrRaw[len(hdrRaw)-2] ^= 0xff // invert the packet number
   303  		hdrRaw[len(hdrRaw)-1] ^= 0xff // invert the packet number
   304  		Expect(hdrRaw[0]).ToNot(Equal(firstHdrByte))
   305  		opener := mocks.NewMockLongHeaderOpener(mockCtrl)
   306  		cs.EXPECT().GetHandshakeOpener().Return(opener, nil)
   307  		gomock.InOrder(
   308  			// we're using a 2 byte packet number, so the sample starts at the 3rd payload byte
   309  			opener.EXPECT().DecryptHeader(
   310  				[]byte{3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18},
   311  				&hdrRaw[0],
   312  				append(hdrRaw[len(hdrRaw)-2:], []byte{1, 2}...)).Do(func(_ []byte, firstByte *byte, pnBytes []byte) {
   313  				*firstByte ^= 0xff // invert the first byte back
   314  				for i := range pnBytes {
   315  					pnBytes[i] ^= 0xff // invert the packet number bytes
   316  				}
   317  			}),
   318  			opener.EXPECT().DecodePacketNumber(protocol.PacketNumber(0x1337), protocol.PacketNumberLen2).Return(protocol.PacketNumber(0x7331)),
   319  			opener.EXPECT().Open(gomock.Any(), gomock.Any(), protocol.PacketNumber(0x7331), origHdrRaw).Return([]byte{0}, nil),
   320  		)
   321  		data := hdrRaw
   322  		for i := 1; i <= 100; i++ {
   323  			data = append(data, uint8(i))
   324  		}
   325  		packet, err := unpacker.UnpackLongHeader(hdr, time.Now(), data, protocol.Version1)
   326  		Expect(err).ToNot(HaveOccurred())
   327  		Expect(packet.hdr.PacketNumber).To(Equal(protocol.PacketNumber(0x7331)))
   328  	})
   329  })