github.com/tumi8/quic-go@v0.37.4-tum/packet_handler_map_test.go (about)

     1  package quic
     2  
     3  import (
     4  	"crypto/rand"
     5  	"errors"
     6  	"net"
     7  	"time"
     8  
     9  
    10  	"github.com/tumi8/quic-go/noninternal/protocol"
    11  	"github.com/tumi8/quic-go/noninternal/utils"
    12  
    13  	. "github.com/onsi/ginkgo/v2"
    14  	. "github.com/onsi/gomega"
    15  )
    16  
    17  var _ = Describe("Packet Handler Map", func() {
    18  	It("adds and gets", func() {
    19  		m := newPacketHandlerMap(nil, nil, utils.DefaultLogger)
    20  		connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4})
    21  		handler := NewMockPacketHandler(mockCtrl)
    22  		Expect(m.Add(connID, handler)).To(BeTrue())
    23  		h, ok := m.Get(connID)
    24  		Expect(ok).To(BeTrue())
    25  		Expect(h).To(Equal(handler))
    26  	})
    27  
    28  	It("refused to add duplicates", func() {
    29  		m := newPacketHandlerMap(nil, nil, utils.DefaultLogger)
    30  		connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4})
    31  		handler := NewMockPacketHandler(mockCtrl)
    32  		Expect(m.Add(connID, handler)).To(BeTrue())
    33  		Expect(m.Add(connID, handler)).To(BeFalse())
    34  	})
    35  
    36  	It("removes", func() {
    37  		m := newPacketHandlerMap(nil, nil, utils.DefaultLogger)
    38  		connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4})
    39  		handler := NewMockPacketHandler(mockCtrl)
    40  		Expect(m.Add(connID, handler)).To(BeTrue())
    41  		m.Remove(connID)
    42  		_, ok := m.Get(connID)
    43  		Expect(ok).To(BeFalse())
    44  		Expect(m.Add(connID, handler)).To(BeTrue())
    45  	})
    46  
    47  	It("retires", func() {
    48  		m := newPacketHandlerMap(nil, nil, utils.DefaultLogger)
    49  		dur := scaleDuration(50 * time.Millisecond)
    50  		m.deleteRetiredConnsAfter = dur
    51  		connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4})
    52  		handler := NewMockPacketHandler(mockCtrl)
    53  		Expect(m.Add(connID, handler)).To(BeTrue())
    54  		m.Retire(connID)
    55  		_, ok := m.Get(connID)
    56  		Expect(ok).To(BeTrue())
    57  		time.Sleep(dur)
    58  		Eventually(func() bool { _, ok := m.Get(connID); return ok }).Should(BeFalse())
    59  	})
    60  
    61  	It("adds newly to-be-constructed handlers", func() {
    62  		m := newPacketHandlerMap(nil, nil, utils.DefaultLogger)
    63  		var called bool
    64  		connID1 := protocol.ParseConnectionID([]byte{1, 2, 3, 4})
    65  		connID2 := protocol.ParseConnectionID([]byte{4, 3, 2, 1})
    66  		Expect(m.AddWithConnID(connID1, connID2, func() (packetHandler, bool) {
    67  			called = true
    68  			return NewMockPacketHandler(mockCtrl), true
    69  		})).To(BeTrue())
    70  		Expect(called).To(BeTrue())
    71  		Expect(m.AddWithConnID(connID1, protocol.ParseConnectionID([]byte{1, 2, 3}), func() (packetHandler, bool) {
    72  			Fail("didn't expect the constructor to be executed")
    73  			return nil, false
    74  		})).To(BeFalse())
    75  	})
    76  
    77  	It("adds, gets and removes reset tokens", func() {
    78  		m := newPacketHandlerMap(nil, nil, utils.DefaultLogger)
    79  		token := protocol.StatelessResetToken{1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 0xa, 0xb, 0xc, 0xd, 0xe, 0xf}
    80  		handler := NewMockPacketHandler(mockCtrl)
    81  		m.AddResetToken(token, handler)
    82  		h, ok := m.GetByResetToken(token)
    83  		Expect(ok).To(BeTrue())
    84  		Expect(h).To(Equal(h))
    85  		m.RemoveResetToken(token)
    86  		_, ok = m.GetByResetToken(token)
    87  		Expect(ok).To(BeFalse())
    88  	})
    89  
    90  	It("generates stateless reset token, if no key is set", func() {
    91  		m := newPacketHandlerMap(nil, nil, utils.DefaultLogger)
    92  		b := make([]byte, 8)
    93  		rand.Read(b)
    94  		connID := protocol.ParseConnectionID(b)
    95  		token := m.GetStatelessResetToken(connID)
    96  		for i := 0; i < 1000; i++ {
    97  			to := m.GetStatelessResetToken(connID)
    98  			Expect(to).ToNot(Equal(token))
    99  			token = to
   100  		}
   101  	})
   102  
   103  	It("generates stateless reset token, if a key is set", func() {
   104  		var key StatelessResetKey
   105  		rand.Read(key[:])
   106  		m := newPacketHandlerMap(&key, nil, utils.DefaultLogger)
   107  		b := make([]byte, 8)
   108  		rand.Read(b)
   109  		connID := protocol.ParseConnectionID(b)
   110  		token := m.GetStatelessResetToken(connID)
   111  		Expect(token).ToNot(BeZero())
   112  		Expect(m.GetStatelessResetToken(connID)).To(Equal(token))
   113  		// generate a new connection ID
   114  		rand.Read(b)
   115  		connID2 := protocol.ParseConnectionID(b)
   116  		Expect(m.GetStatelessResetToken(connID2)).ToNot(Equal(token))
   117  	})
   118  
   119  	It("replaces locally closed connections", func() {
   120  		var closePackets []closePacket
   121  		m := newPacketHandlerMap(nil, func(p closePacket) { closePackets = append(closePackets, p) }, utils.DefaultLogger)
   122  		dur := scaleDuration(50 * time.Millisecond)
   123  		m.deleteRetiredConnsAfter = dur
   124  
   125  		handler := NewMockPacketHandler(mockCtrl)
   126  		connID := protocol.ParseConnectionID([]byte{4, 3, 2, 1})
   127  		Expect(m.Add(connID, handler)).To(BeTrue())
   128  		m.ReplaceWithClosed([]protocol.ConnectionID{connID}, protocol.PerspectiveClient, []byte("foobar"))
   129  		h, ok := m.Get(connID)
   130  		Expect(ok).To(BeTrue())
   131  		Expect(h).ToNot(Equal(handler))
   132  		addr := &net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 1234}
   133  		h.handlePacket(receivedPacket{remoteAddr: addr})
   134  		Expect(closePackets).To(HaveLen(1))
   135  		Expect(closePackets[0].addr).To(Equal(addr))
   136  		Expect(closePackets[0].payload).To(Equal([]byte("foobar")))
   137  
   138  		time.Sleep(dur)
   139  		Eventually(func() bool { _, ok := m.Get(connID); return ok }).Should(BeFalse())
   140  	})
   141  
   142  	It("replaces remote closed connections", func() {
   143  		var closePackets []closePacket
   144  		m := newPacketHandlerMap(nil, func(p closePacket) { closePackets = append(closePackets, p) }, utils.DefaultLogger)
   145  		dur := scaleDuration(50 * time.Millisecond)
   146  		m.deleteRetiredConnsAfter = dur
   147  
   148  		handler := NewMockPacketHandler(mockCtrl)
   149  		connID := protocol.ParseConnectionID([]byte{4, 3, 2, 1})
   150  		Expect(m.Add(connID, handler)).To(BeTrue())
   151  		m.ReplaceWithClosed([]protocol.ConnectionID{connID}, protocol.PerspectiveClient, nil)
   152  		h, ok := m.Get(connID)
   153  		Expect(ok).To(BeTrue())
   154  		Expect(h).ToNot(Equal(handler))
   155  		addr := &net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 1234}
   156  		h.handlePacket(receivedPacket{remoteAddr: addr})
   157  		Expect(closePackets).To(BeEmpty())
   158  
   159  		time.Sleep(dur)
   160  		Eventually(func() bool { _, ok := m.Get(connID); return ok }).Should(BeFalse())
   161  	})
   162  
   163  	It("closes the server", func() {
   164  		m := newPacketHandlerMap(nil, nil, utils.DefaultLogger)
   165  		for i := 0; i < 10; i++ {
   166  			conn := NewMockPacketHandler(mockCtrl)
   167  			if i%2 == 0 {
   168  				conn.EXPECT().getPerspective().Return(protocol.PerspectiveClient)
   169  			} else {
   170  				conn.EXPECT().getPerspective().Return(protocol.PerspectiveServer)
   171  				conn.EXPECT().shutdown()
   172  			}
   173  			b := make([]byte, 12)
   174  			rand.Read(b)
   175  			m.Add(protocol.ParseConnectionID(b), conn)
   176  		}
   177  		m.CloseServer()
   178  	})
   179  
   180  	It("closes", func() {
   181  		m := newPacketHandlerMap(nil, nil, utils.DefaultLogger)
   182  		testErr := errors.New("shutdown")
   183  		for i := 0; i < 10; i++ {
   184  			conn := NewMockPacketHandler(mockCtrl)
   185  			conn.EXPECT().destroy(testErr)
   186  			b := make([]byte, 12)
   187  			rand.Read(b)
   188  			m.Add(protocol.ParseConnectionID(b), conn)
   189  		}
   190  		m.Close(testErr)
   191  		// check that Close can be called multiple times
   192  		m.Close(errors.New("close"))
   193  	})
   194  })