github.com/apernet/quic-go@v0.43.1-0.20240515053213-5e9e635fd9f0/internal/wire/short_header_test.go (about)

     1  package wire
     2  
     3  import (
     4  	"bytes"
     5  	"io"
     6  	"log"
     7  	"os"
     8  	"testing"
     9  
    10  	"github.com/apernet/quic-go/internal/protocol"
    11  	"github.com/apernet/quic-go/internal/utils"
    12  
    13  	. "github.com/onsi/ginkgo/v2"
    14  	. "github.com/onsi/gomega"
    15  )
    16  
    17  var _ = Describe("Short Header", func() {
    18  	Context("Parsing", func() {
    19  		It("parses", func() {
    20  			data := []byte{
    21  				0b01000110,
    22  				0xde, 0xad, 0xbe, 0xef,
    23  				0x13, 0x37, 0x99,
    24  			}
    25  			l, pn, pnLen, kp, err := ParseShortHeader(data, 4)
    26  			Expect(err).ToNot(HaveOccurred())
    27  			Expect(l).To(Equal(len(data)))
    28  			Expect(kp).To(Equal(protocol.KeyPhaseOne))
    29  			Expect(pn).To(Equal(protocol.PacketNumber(0x133799)))
    30  			Expect(pnLen).To(Equal(protocol.PacketNumberLen3))
    31  		})
    32  
    33  		It("errors when the QUIC bit is not set", func() {
    34  			data := []byte{
    35  				0b00000101,
    36  				0xde, 0xad, 0xbe, 0xef,
    37  				0x13, 0x37,
    38  			}
    39  			_, _, _, _, err := ParseShortHeader(data, 4)
    40  			Expect(err).To(MatchError("not a QUIC packet"))
    41  		})
    42  
    43  		It("errors, but returns the header, when the reserved bits are set", func() {
    44  			data := []byte{
    45  				0b01010101,
    46  				0xde, 0xad, 0xbe, 0xef,
    47  				0x13, 0x37,
    48  			}
    49  			_, pn, _, _, err := ParseShortHeader(data, 4)
    50  			Expect(err).To(MatchError(ErrInvalidReservedBits))
    51  			Expect(pn).To(Equal(protocol.PacketNumber(0x1337)))
    52  		})
    53  
    54  		It("errors when passed a long header packet", func() {
    55  			_, _, _, _, err := ParseShortHeader([]byte{0x80}, 4)
    56  			Expect(err).To(MatchError("not a short header packet"))
    57  		})
    58  
    59  		It("errors on EOF", func() {
    60  			data := []byte{
    61  				0b01000110,
    62  				0xde, 0xad, 0xbe, 0xef,
    63  				0x13, 0x37, 0x99,
    64  			}
    65  			_, _, _, _, err := ParseShortHeader(data, 4)
    66  			Expect(err).ToNot(HaveOccurred())
    67  			for i := range data {
    68  				_, _, _, _, err := ParseShortHeader(data[:i], 4)
    69  				Expect(err).To(MatchError(io.EOF))
    70  			}
    71  		})
    72  	})
    73  
    74  	It("determines the length", func() {
    75  		Expect(ShortHeaderLen(protocol.ParseConnectionID([]byte{1, 2, 3, 4}), protocol.PacketNumberLen3)).To(BeEquivalentTo(8))
    76  		Expect(ShortHeaderLen(protocol.ParseConnectionID([]byte{}), protocol.PacketNumberLen1)).To(BeEquivalentTo(2))
    77  	})
    78  
    79  	Context("writing", func() {
    80  		It("writes a short header packet", func() {
    81  			connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4})
    82  			b, err := AppendShortHeader(nil, connID, 1337, 4, protocol.KeyPhaseOne)
    83  			Expect(err).ToNot(HaveOccurred())
    84  			l, pn, pnLen, kp, err := ParseShortHeader(b, 4)
    85  			Expect(err).ToNot(HaveOccurred())
    86  			Expect(pn).To(Equal(protocol.PacketNumber(1337)))
    87  			Expect(pnLen).To(Equal(protocol.PacketNumberLen4))
    88  			Expect(kp).To(Equal(protocol.KeyPhaseOne))
    89  			Expect(l).To(Equal(len(b)))
    90  		})
    91  	})
    92  
    93  	Context("logging", func() {
    94  		var (
    95  			buf    *bytes.Buffer
    96  			logger utils.Logger
    97  		)
    98  
    99  		BeforeEach(func() {
   100  			buf = &bytes.Buffer{}
   101  			logger = utils.DefaultLogger
   102  			logger.SetLogLevel(utils.LogLevelDebug)
   103  			log.SetOutput(buf)
   104  		})
   105  
   106  		AfterEach(func() {
   107  			log.SetOutput(os.Stdout)
   108  		})
   109  
   110  		It("logs Short Headers containing a connection ID", func() {
   111  			connID := protocol.ParseConnectionID([]byte{0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0x13, 0x37})
   112  			LogShortHeader(logger, connID, 1337, protocol.PacketNumberLen4, protocol.KeyPhaseOne)
   113  			Expect(buf.String()).To(ContainSubstring("Short Header{DestConnectionID: deadbeefcafe1337, PacketNumber: 1337, PacketNumberLen: 4, KeyPhase: 1}"))
   114  		})
   115  	})
   116  })
   117  
   118  func BenchmarkWriteShortHeader(b *testing.B) {
   119  	b.ReportAllocs()
   120  	buf := make([]byte, 100)
   121  	connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6})
   122  	for i := 0; i < b.N; i++ {
   123  		var err error
   124  		buf, err = AppendShortHeader(buf, connID, 1337, protocol.PacketNumberLen4, protocol.KeyPhaseOne)
   125  		if err != nil {
   126  			b.Fatalf("failed to write short header: %s", err)
   127  		}
   128  		buf = buf[:0]
   129  	}
   130  }