// Source: github.com/apernet/quic-go@v0.43.1-0.20240515053213-5e9e635fd9f0/packet_handler_map_test.go

     1  package quic
     2  
     3  import (
     4  	"crypto/rand"
     5  	"errors"
     6  	"net"
     7  	"time"
     8  
     9  	"github.com/apernet/quic-go/internal/protocol"
    10  	"github.com/apernet/quic-go/internal/utils"
    11  
    12  	. "github.com/onsi/ginkgo/v2"
    13  	. "github.com/onsi/gomega"
    14  )
    15  
    16  var _ = Describe("Packet Handler Map", func() {
    17  	It("adds and gets", func() {
    18  		m := newPacketHandlerMap(nil, nil, utils.DefaultLogger)
    19  		connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4})
    20  		handler := NewMockPacketHandler(mockCtrl)
    21  		Expect(m.Add(connID, handler)).To(BeTrue())
    22  		h, ok := m.Get(connID)
    23  		Expect(ok).To(BeTrue())
    24  		Expect(h).To(Equal(handler))
    25  	})
    26  
    27  	It("refused to add duplicates", func() {
    28  		m := newPacketHandlerMap(nil, nil, utils.DefaultLogger)
    29  		connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4})
    30  		handler := NewMockPacketHandler(mockCtrl)
    31  		Expect(m.Add(connID, handler)).To(BeTrue())
    32  		Expect(m.Add(connID, handler)).To(BeFalse())
    33  	})
    34  
    35  	It("removes", func() {
    36  		m := newPacketHandlerMap(nil, nil, utils.DefaultLogger)
    37  		connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4})
    38  		handler := NewMockPacketHandler(mockCtrl)
    39  		Expect(m.Add(connID, handler)).To(BeTrue())
    40  		m.Remove(connID)
    41  		_, ok := m.Get(connID)
    42  		Expect(ok).To(BeFalse())
    43  		Expect(m.Add(connID, handler)).To(BeTrue())
    44  	})
    45  
    46  	It("retires", func() {
    47  		m := newPacketHandlerMap(nil, nil, utils.DefaultLogger)
    48  		dur := scaleDuration(50 * time.Millisecond)
    49  		m.deleteRetiredConnsAfter = dur
    50  		connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4})
    51  		handler := NewMockPacketHandler(mockCtrl)
    52  		Expect(m.Add(connID, handler)).To(BeTrue())
    53  		m.Retire(connID)
    54  		_, ok := m.Get(connID)
    55  		Expect(ok).To(BeTrue())
    56  		time.Sleep(dur)
    57  		Eventually(func() bool { _, ok := m.Get(connID); return ok }).Should(BeFalse())
    58  	})
    59  
    60  	It("adds newly to-be-constructed handlers", func() {
    61  		m := newPacketHandlerMap(nil, nil, utils.DefaultLogger)
    62  		connID1 := protocol.ParseConnectionID([]byte{1, 2, 3, 4})
    63  		connID2 := protocol.ParseConnectionID([]byte{4, 3, 2, 1})
    64  		h := NewMockPacketHandler(mockCtrl)
    65  		Expect(m.AddWithConnID(connID1, connID2, h)).To(BeTrue())
    66  		// collision of the destination connection ID, this handler should not be added
    67  		Expect(m.AddWithConnID(connID1, protocol.ParseConnectionID([]byte{1, 2, 3}), nil)).To(BeFalse())
    68  	})
    69  
    70  	It("adds, gets and removes reset tokens", func() {
    71  		m := newPacketHandlerMap(nil, nil, utils.DefaultLogger)
    72  		token := protocol.StatelessResetToken{1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 0xa, 0xb, 0xc, 0xd, 0xe, 0xf}
    73  		handler := NewMockPacketHandler(mockCtrl)
    74  		m.AddResetToken(token, handler)
    75  		h, ok := m.GetByResetToken(token)
    76  		Expect(ok).To(BeTrue())
    77  		Expect(h).To(Equal(h))
    78  		m.RemoveResetToken(token)
    79  		_, ok = m.GetByResetToken(token)
    80  		Expect(ok).To(BeFalse())
    81  	})
    82  
    83  	It("generates stateless reset token, if no key is set", func() {
    84  		m := newPacketHandlerMap(nil, nil, utils.DefaultLogger)
    85  		b := make([]byte, 8)
    86  		rand.Read(b)
    87  		connID := protocol.ParseConnectionID(b)
    88  		token := m.GetStatelessResetToken(connID)
    89  		for i := 0; i < 1000; i++ {
    90  			to := m.GetStatelessResetToken(connID)
    91  			Expect(to).ToNot(Equal(token))
    92  			token = to
    93  		}
    94  	})
    95  
    96  	It("generates stateless reset token, if a key is set", func() {
    97  		var key StatelessResetKey
    98  		rand.Read(key[:])
    99  		m := newPacketHandlerMap(&key, nil, utils.DefaultLogger)
   100  		b := make([]byte, 8)
   101  		rand.Read(b)
   102  		connID := protocol.ParseConnectionID(b)
   103  		token := m.GetStatelessResetToken(connID)
   104  		Expect(token).ToNot(BeZero())
   105  		Expect(m.GetStatelessResetToken(connID)).To(Equal(token))
   106  		// generate a new connection ID
   107  		rand.Read(b)
   108  		connID2 := protocol.ParseConnectionID(b)
   109  		Expect(m.GetStatelessResetToken(connID2)).ToNot(Equal(token))
   110  	})
   111  
   112  	It("replaces locally closed connections", func() {
   113  		var closePackets []closePacket
   114  		m := newPacketHandlerMap(nil, func(p closePacket) { closePackets = append(closePackets, p) }, utils.DefaultLogger)
   115  		dur := scaleDuration(50 * time.Millisecond)
   116  		m.deleteRetiredConnsAfter = dur
   117  
   118  		handler := NewMockPacketHandler(mockCtrl)
   119  		connID := protocol.ParseConnectionID([]byte{4, 3, 2, 1})
   120  		Expect(m.Add(connID, handler)).To(BeTrue())
   121  		m.ReplaceWithClosed([]protocol.ConnectionID{connID}, []byte("foobar"))
   122  		h, ok := m.Get(connID)
   123  		Expect(ok).To(BeTrue())
   124  		Expect(h).ToNot(Equal(handler))
   125  		addr := &net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 1234}
   126  		h.handlePacket(receivedPacket{remoteAddr: addr})
   127  		Expect(closePackets).To(HaveLen(1))
   128  		Expect(closePackets[0].addr).To(Equal(addr))
   129  		Expect(closePackets[0].payload).To(Equal([]byte("foobar")))
   130  
   131  		time.Sleep(dur)
   132  		Eventually(func() bool { _, ok := m.Get(connID); return ok }).Should(BeFalse())
   133  	})
   134  
   135  	It("replaces remote closed connections", func() {
   136  		var closePackets []closePacket
   137  		m := newPacketHandlerMap(nil, func(p closePacket) { closePackets = append(closePackets, p) }, utils.DefaultLogger)
   138  		dur := scaleDuration(50 * time.Millisecond)
   139  		m.deleteRetiredConnsAfter = dur
   140  
   141  		handler := NewMockPacketHandler(mockCtrl)
   142  		connID := protocol.ParseConnectionID([]byte{4, 3, 2, 1})
   143  		Expect(m.Add(connID, handler)).To(BeTrue())
   144  		m.ReplaceWithClosed([]protocol.ConnectionID{connID}, nil)
   145  		h, ok := m.Get(connID)
   146  		Expect(ok).To(BeTrue())
   147  		Expect(h).ToNot(Equal(handler))
   148  		addr := &net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 1234}
   149  		h.handlePacket(receivedPacket{remoteAddr: addr})
   150  		Expect(closePackets).To(BeEmpty())
   151  
   152  		time.Sleep(dur)
   153  		Eventually(func() bool { _, ok := m.Get(connID); return ok }).Should(BeFalse())
   154  	})
   155  
   156  	It("closes", func() {
   157  		m := newPacketHandlerMap(nil, nil, utils.DefaultLogger)
   158  		testErr := errors.New("shutdown")
   159  		for i := 0; i < 10; i++ {
   160  			conn := NewMockPacketHandler(mockCtrl)
   161  			conn.EXPECT().destroy(testErr)
   162  			b := make([]byte, 12)
   163  			rand.Read(b)
   164  			m.Add(protocol.ParseConnectionID(b), conn)
   165  		}
   166  		m.Close(testErr)
   167  		// check that Close can be called multiple times
   168  		m.Close(errors.New("close"))
   169  	})
   170  })