// Source: github.com/MerlinKodo/quic-go@v0.39.2/packet_handler_map_test.go

     1  package quic
     2  
     3  import (
     4  	"crypto/rand"
     5  	"errors"
     6  	"net"
     7  	"time"
     8  
     9  	"github.com/MerlinKodo/quic-go/internal/protocol"
    10  	"github.com/MerlinKodo/quic-go/internal/utils"
    11  
    12  	. "github.com/onsi/ginkgo/v2"
    13  	. "github.com/onsi/gomega"
    14  )
    15  
    16  var _ = Describe("Packet Handler Map", func() {
    17  	It("adds and gets", func() {
    18  		m := newPacketHandlerMap(nil, nil, utils.DefaultLogger)
    19  		connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4})
    20  		handler := NewMockPacketHandler(mockCtrl)
    21  		Expect(m.Add(connID, handler)).To(BeTrue())
    22  		h, ok := m.Get(connID)
    23  		Expect(ok).To(BeTrue())
    24  		Expect(h).To(Equal(handler))
    25  	})
    26  
    27  	It("refused to add duplicates", func() {
    28  		m := newPacketHandlerMap(nil, nil, utils.DefaultLogger)
    29  		connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4})
    30  		handler := NewMockPacketHandler(mockCtrl)
    31  		Expect(m.Add(connID, handler)).To(BeTrue())
    32  		Expect(m.Add(connID, handler)).To(BeFalse())
    33  	})
    34  
    35  	It("removes", func() {
    36  		m := newPacketHandlerMap(nil, nil, utils.DefaultLogger)
    37  		connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4})
    38  		handler := NewMockPacketHandler(mockCtrl)
    39  		Expect(m.Add(connID, handler)).To(BeTrue())
    40  		m.Remove(connID)
    41  		_, ok := m.Get(connID)
    42  		Expect(ok).To(BeFalse())
    43  		Expect(m.Add(connID, handler)).To(BeTrue())
    44  	})
    45  
    46  	It("retires", func() {
    47  		m := newPacketHandlerMap(nil, nil, utils.DefaultLogger)
    48  		dur := scaleDuration(50 * time.Millisecond)
    49  		m.deleteRetiredConnsAfter = dur
    50  		connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4})
    51  		handler := NewMockPacketHandler(mockCtrl)
    52  		Expect(m.Add(connID, handler)).To(BeTrue())
    53  		m.Retire(connID)
    54  		_, ok := m.Get(connID)
    55  		Expect(ok).To(BeTrue())
    56  		time.Sleep(dur)
    57  		Eventually(func() bool { _, ok := m.Get(connID); return ok }).Should(BeFalse())
    58  	})
    59  
    60  	It("adds newly to-be-constructed handlers", func() {
    61  		m := newPacketHandlerMap(nil, nil, utils.DefaultLogger)
    62  		var called bool
    63  		connID1 := protocol.ParseConnectionID([]byte{1, 2, 3, 4})
    64  		connID2 := protocol.ParseConnectionID([]byte{4, 3, 2, 1})
    65  		Expect(m.AddWithConnID(connID1, connID2, func() (packetHandler, bool) {
    66  			called = true
    67  			return NewMockPacketHandler(mockCtrl), true
    68  		})).To(BeTrue())
    69  		Expect(called).To(BeTrue())
    70  		Expect(m.AddWithConnID(connID1, protocol.ParseConnectionID([]byte{1, 2, 3}), func() (packetHandler, bool) {
    71  			Fail("didn't expect the constructor to be executed")
    72  			return nil, false
    73  		})).To(BeFalse())
    74  	})
    75  
    76  	It("adds, gets and removes reset tokens", func() {
    77  		m := newPacketHandlerMap(nil, nil, utils.DefaultLogger)
    78  		token := protocol.StatelessResetToken{1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 0xa, 0xb, 0xc, 0xd, 0xe, 0xf}
    79  		handler := NewMockPacketHandler(mockCtrl)
    80  		m.AddResetToken(token, handler)
    81  		h, ok := m.GetByResetToken(token)
    82  		Expect(ok).To(BeTrue())
    83  		Expect(h).To(Equal(h))
    84  		m.RemoveResetToken(token)
    85  		_, ok = m.GetByResetToken(token)
    86  		Expect(ok).To(BeFalse())
    87  	})
    88  
    89  	It("generates stateless reset token, if no key is set", func() {
    90  		m := newPacketHandlerMap(nil, nil, utils.DefaultLogger)
    91  		b := make([]byte, 8)
    92  		rand.Read(b)
    93  		connID := protocol.ParseConnectionID(b)
    94  		token := m.GetStatelessResetToken(connID)
    95  		for i := 0; i < 1000; i++ {
    96  			to := m.GetStatelessResetToken(connID)
    97  			Expect(to).ToNot(Equal(token))
    98  			token = to
    99  		}
   100  	})
   101  
   102  	It("generates stateless reset token, if a key is set", func() {
   103  		var key StatelessResetKey
   104  		rand.Read(key[:])
   105  		m := newPacketHandlerMap(&key, nil, utils.DefaultLogger)
   106  		b := make([]byte, 8)
   107  		rand.Read(b)
   108  		connID := protocol.ParseConnectionID(b)
   109  		token := m.GetStatelessResetToken(connID)
   110  		Expect(token).ToNot(BeZero())
   111  		Expect(m.GetStatelessResetToken(connID)).To(Equal(token))
   112  		// generate a new connection ID
   113  		rand.Read(b)
   114  		connID2 := protocol.ParseConnectionID(b)
   115  		Expect(m.GetStatelessResetToken(connID2)).ToNot(Equal(token))
   116  	})
   117  
   118  	It("replaces locally closed connections", func() {
   119  		var closePackets []closePacket
   120  		m := newPacketHandlerMap(nil, func(p closePacket) { closePackets = append(closePackets, p) }, utils.DefaultLogger)
   121  		dur := scaleDuration(50 * time.Millisecond)
   122  		m.deleteRetiredConnsAfter = dur
   123  
   124  		handler := NewMockPacketHandler(mockCtrl)
   125  		connID := protocol.ParseConnectionID([]byte{4, 3, 2, 1})
   126  		Expect(m.Add(connID, handler)).To(BeTrue())
   127  		m.ReplaceWithClosed([]protocol.ConnectionID{connID}, protocol.PerspectiveClient, []byte("foobar"))
   128  		h, ok := m.Get(connID)
   129  		Expect(ok).To(BeTrue())
   130  		Expect(h).ToNot(Equal(handler))
   131  		addr := &net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 1234}
   132  		h.handlePacket(receivedPacket{remoteAddr: addr})
   133  		Expect(closePackets).To(HaveLen(1))
   134  		Expect(closePackets[0].addr).To(Equal(addr))
   135  		Expect(closePackets[0].payload).To(Equal([]byte("foobar")))
   136  
   137  		time.Sleep(dur)
   138  		Eventually(func() bool { _, ok := m.Get(connID); return ok }).Should(BeFalse())
   139  	})
   140  
   141  	It("replaces remote closed connections", func() {
   142  		var closePackets []closePacket
   143  		m := newPacketHandlerMap(nil, func(p closePacket) { closePackets = append(closePackets, p) }, utils.DefaultLogger)
   144  		dur := scaleDuration(50 * time.Millisecond)
   145  		m.deleteRetiredConnsAfter = dur
   146  
   147  		handler := NewMockPacketHandler(mockCtrl)
   148  		connID := protocol.ParseConnectionID([]byte{4, 3, 2, 1})
   149  		Expect(m.Add(connID, handler)).To(BeTrue())
   150  		m.ReplaceWithClosed([]protocol.ConnectionID{connID}, protocol.PerspectiveClient, nil)
   151  		h, ok := m.Get(connID)
   152  		Expect(ok).To(BeTrue())
   153  		Expect(h).ToNot(Equal(handler))
   154  		addr := &net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 1234}
   155  		h.handlePacket(receivedPacket{remoteAddr: addr})
   156  		Expect(closePackets).To(BeEmpty())
   157  
   158  		time.Sleep(dur)
   159  		Eventually(func() bool { _, ok := m.Get(connID); return ok }).Should(BeFalse())
   160  	})
   161  
   162  	It("closes the server", func() {
   163  		m := newPacketHandlerMap(nil, nil, utils.DefaultLogger)
   164  		for i := 0; i < 10; i++ {
   165  			conn := NewMockPacketHandler(mockCtrl)
   166  			if i%2 == 0 {
   167  				conn.EXPECT().getPerspective().Return(protocol.PerspectiveClient)
   168  			} else {
   169  				conn.EXPECT().getPerspective().Return(protocol.PerspectiveServer)
   170  				conn.EXPECT().shutdown()
   171  			}
   172  			b := make([]byte, 12)
   173  			rand.Read(b)
   174  			m.Add(protocol.ParseConnectionID(b), conn)
   175  		}
   176  		m.CloseServer()
   177  	})
   178  
   179  	It("closes", func() {
   180  		m := newPacketHandlerMap(nil, nil, utils.DefaultLogger)
   181  		testErr := errors.New("shutdown")
   182  		for i := 0; i < 10; i++ {
   183  			conn := NewMockPacketHandler(mockCtrl)
   184  			conn.EXPECT().destroy(testErr)
   185  			b := make([]byte, 12)
   186  			rand.Read(b)
   187  			m.Add(protocol.ParseConnectionID(b), conn)
   188  		}
   189  		m.Close(testErr)
   190  		// check that Close can be called multiple times
   191  		m.Close(errors.New("close"))
   192  	})
   193  })