github.com/mikelsr/quic-go@v0.36.1-0.20230701132136-1d9415b66898/streams_map_incoming_test.go (about)

     1  package quic
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  	"math/rand"
     7  	"time"
     8  
     9  	"github.com/mikelsr/quic-go/internal/protocol"
    10  	"github.com/mikelsr/quic-go/internal/wire"
    11  
    12  	"github.com/golang/mock/gomock"
    13  	. "github.com/onsi/ginkgo/v2"
    14  	. "github.com/onsi/gomega"
    15  )
    16  
    17  type mockGenericStream struct {
    18  	num protocol.StreamNum
    19  
    20  	closed     bool
    21  	closeErr   error
    22  	sendWindow protocol.ByteCount
    23  }
    24  
    25  func (s *mockGenericStream) closeForShutdown(err error) {
    26  	s.closed = true
    27  	s.closeErr = err
    28  }
    29  
    30  func (s *mockGenericStream) updateSendWindow(limit protocol.ByteCount) {
    31  	s.sendWindow = limit
    32  }
    33  
    34  var _ = Describe("Streams Map (incoming)", func() {
    35  	var (
    36  		m              *incomingStreamsMap[*mockGenericStream]
    37  		newItemCounter int
    38  		mockSender     *MockStreamSender
    39  		maxNumStreams  uint64
    40  	)
    41  	streamType := []protocol.StreamType{protocol.StreamTypeUni, protocol.StreamTypeUni}[rand.Intn(2)]
    42  
    43  	// check that the frame can be serialized and deserialized
    44  	checkFrameSerialization := func(f wire.Frame) {
    45  		b, err := f.Append(nil, protocol.Version1)
    46  		ExpectWithOffset(1, err).ToNot(HaveOccurred())
    47  		_, frame, err := wire.NewFrameParser(false).ParseNext(b, protocol.Encryption1RTT, protocol.Version1)
    48  		ExpectWithOffset(1, err).ToNot(HaveOccurred())
    49  		Expect(f).To(Equal(frame))
    50  	}
    51  
    52  	BeforeEach(func() { maxNumStreams = 5 })
    53  
    54  	JustBeforeEach(func() {
    55  		newItemCounter = 0
    56  		mockSender = NewMockStreamSender(mockCtrl)
    57  		m = newIncomingStreamsMap(
    58  			streamType,
    59  			func(num protocol.StreamNum) *mockGenericStream {
    60  				newItemCounter++
    61  				return &mockGenericStream{num: num}
    62  			},
    63  			maxNumStreams,
    64  			mockSender.queueControlFrame,
    65  		)
    66  	})
    67  
    68  	It("opens all streams up to the id on GetOrOpenStream", func() {
    69  		_, err := m.GetOrOpenStream(4)
    70  		Expect(err).ToNot(HaveOccurred())
    71  		Expect(newItemCounter).To(Equal(4))
    72  	})
    73  
    74  	It("starts opening streams at the right position", func() {
    75  		// like the test above, but with 2 calls to GetOrOpenStream
    76  		_, err := m.GetOrOpenStream(2)
    77  		Expect(err).ToNot(HaveOccurred())
    78  		Expect(newItemCounter).To(Equal(2))
    79  		_, err = m.GetOrOpenStream(5)
    80  		Expect(err).ToNot(HaveOccurred())
    81  		Expect(newItemCounter).To(Equal(5))
    82  	})
    83  
    84  	It("accepts streams in the right order", func() {
    85  		_, err := m.GetOrOpenStream(2) // open streams 1 and 2
    86  		Expect(err).ToNot(HaveOccurred())
    87  		str, err := m.AcceptStream(context.Background())
    88  		Expect(err).ToNot(HaveOccurred())
    89  		Expect(str.num).To(Equal(protocol.StreamNum(1)))
    90  		str, err = m.AcceptStream(context.Background())
    91  		Expect(err).ToNot(HaveOccurred())
    92  		Expect(str.num).To(Equal(protocol.StreamNum(2)))
    93  	})
    94  
    95  	It("allows opening the maximum stream ID", func() {
    96  		str, err := m.GetOrOpenStream(1)
    97  		Expect(err).ToNot(HaveOccurred())
    98  		Expect(str.num).To(Equal(protocol.StreamNum(1)))
    99  	})
   100  
   101  	It("errors when trying to get a stream ID higher than the maximum", func() {
   102  		_, err := m.GetOrOpenStream(6)
   103  		Expect(err).To(HaveOccurred())
   104  		Expect(err.(streamError).TestError()).To(MatchError("peer tried to open stream 6 (current limit: 5)"))
   105  	})
   106  
   107  	It("blocks AcceptStream until a new stream is available", func() {
   108  		strChan := make(chan *mockGenericStream)
   109  		go func() {
   110  			defer GinkgoRecover()
   111  			str, err := m.AcceptStream(context.Background())
   112  			Expect(err).ToNot(HaveOccurred())
   113  			strChan <- str
   114  		}()
   115  		Consistently(strChan).ShouldNot(Receive())
   116  		str, err := m.GetOrOpenStream(1)
   117  		Expect(err).ToNot(HaveOccurred())
   118  		Expect(str.num).To(Equal(protocol.StreamNum(1)))
   119  		var acceptedStr *mockGenericStream
   120  		Eventually(strChan).Should(Receive(&acceptedStr))
   121  		Expect(acceptedStr.num).To(Equal(protocol.StreamNum(1)))
   122  	})
   123  
   124  	It("unblocks AcceptStream when the context is canceled", func() {
   125  		ctx, cancel := context.WithCancel(context.Background())
   126  		done := make(chan struct{})
   127  		go func() {
   128  			defer GinkgoRecover()
   129  			_, err := m.AcceptStream(ctx)
   130  			Expect(err).To(MatchError("context canceled"))
   131  			close(done)
   132  		}()
   133  		Consistently(done).ShouldNot(BeClosed())
   134  		cancel()
   135  		Eventually(done).Should(BeClosed())
   136  	})
   137  
   138  	It("unblocks AcceptStream when it is closed", func() {
   139  		testErr := errors.New("test error")
   140  		done := make(chan struct{})
   141  		go func() {
   142  			defer GinkgoRecover()
   143  			_, err := m.AcceptStream(context.Background())
   144  			Expect(err).To(MatchError(testErr))
   145  			close(done)
   146  		}()
   147  		Consistently(done).ShouldNot(BeClosed())
   148  		m.CloseWithError(testErr)
   149  		Eventually(done).Should(BeClosed())
   150  	})
   151  
   152  	It("errors AcceptStream immediately if it is closed", func() {
   153  		testErr := errors.New("test error")
   154  		m.CloseWithError(testErr)
   155  		_, err := m.AcceptStream(context.Background())
   156  		Expect(err).To(MatchError(testErr))
   157  	})
   158  
   159  	It("closes all streams when CloseWithError is called", func() {
   160  		str1, err := m.GetOrOpenStream(1)
   161  		Expect(err).ToNot(HaveOccurred())
   162  		str2, err := m.GetOrOpenStream(3)
   163  		Expect(err).ToNot(HaveOccurred())
   164  		testErr := errors.New("test err")
   165  		m.CloseWithError(testErr)
   166  		Expect(str1.closed).To(BeTrue())
   167  		Expect(str1.closeErr).To(MatchError(testErr))
   168  		Expect(str2.closed).To(BeTrue())
   169  		Expect(str2.closeErr).To(MatchError(testErr))
   170  	})
   171  
   172  	It("deletes streams", func() {
   173  		mockSender.EXPECT().queueControlFrame(gomock.Any())
   174  		_, err := m.GetOrOpenStream(1)
   175  		Expect(err).ToNot(HaveOccurred())
   176  		str, err := m.AcceptStream(context.Background())
   177  		Expect(err).ToNot(HaveOccurred())
   178  		Expect(str.num).To(Equal(protocol.StreamNum(1)))
   179  		Expect(m.DeleteStream(1)).To(Succeed())
   180  		str, err = m.GetOrOpenStream(1)
   181  		Expect(err).ToNot(HaveOccurred())
   182  		Expect(str).To(BeNil())
   183  	})
   184  
   185  	It("waits until a stream is accepted before actually deleting it", func() {
   186  		_, err := m.GetOrOpenStream(2)
   187  		Expect(err).ToNot(HaveOccurred())
   188  		Expect(m.DeleteStream(2)).To(Succeed())
   189  		str, err := m.AcceptStream(context.Background())
   190  		Expect(err).ToNot(HaveOccurred())
   191  		Expect(str.num).To(Equal(protocol.StreamNum(1)))
   192  		// when accepting this stream, it will get deleted, and a MAX_STREAMS frame is queued
   193  		mockSender.EXPECT().queueControlFrame(gomock.Any())
   194  		str, err = m.AcceptStream(context.Background())
   195  		Expect(err).ToNot(HaveOccurred())
   196  		Expect(str.num).To(Equal(protocol.StreamNum(2)))
   197  	})
   198  
   199  	It("doesn't return a stream queued for deleting from GetOrOpenStream", func() {
   200  		str, err := m.GetOrOpenStream(1)
   201  		Expect(err).ToNot(HaveOccurred())
   202  		Expect(str).ToNot(BeNil())
   203  		Expect(m.DeleteStream(1)).To(Succeed())
   204  		str, err = m.GetOrOpenStream(1)
   205  		Expect(err).ToNot(HaveOccurred())
   206  		Expect(str).To(BeNil())
   207  		// when accepting this stream, it will get deleted, and a MAX_STREAMS frame is queued
   208  		mockSender.EXPECT().queueControlFrame(gomock.Any())
   209  		str, err = m.AcceptStream(context.Background())
   210  		Expect(err).ToNot(HaveOccurred())
   211  		Expect(str).ToNot(BeNil())
   212  	})
   213  
   214  	It("errors when deleting a non-existing stream", func() {
   215  		err := m.DeleteStream(1337)
   216  		Expect(err).To(HaveOccurred())
   217  		Expect(err.(streamError).TestError()).To(MatchError("tried to delete unknown incoming stream 1337"))
   218  	})
   219  
   220  	It("sends MAX_STREAMS frames when streams are deleted", func() {
   221  		// open a bunch of streams
   222  		_, err := m.GetOrOpenStream(5)
   223  		Expect(err).ToNot(HaveOccurred())
   224  		// accept all streams
   225  		for i := 0; i < 5; i++ {
   226  			_, err := m.AcceptStream(context.Background())
   227  			Expect(err).ToNot(HaveOccurred())
   228  		}
   229  		mockSender.EXPECT().queueControlFrame(gomock.Any()).Do(func(f wire.Frame) {
   230  			msf := f.(*wire.MaxStreamsFrame)
   231  			Expect(msf.Type).To(BeEquivalentTo(streamType))
   232  			Expect(msf.MaxStreamNum).To(Equal(protocol.StreamNum(maxNumStreams + 1)))
   233  			checkFrameSerialization(f)
   234  		})
   235  		Expect(m.DeleteStream(3)).To(Succeed())
   236  		mockSender.EXPECT().queueControlFrame(gomock.Any()).Do(func(f wire.Frame) {
   237  			Expect(f.(*wire.MaxStreamsFrame).MaxStreamNum).To(Equal(protocol.StreamNum(maxNumStreams + 2)))
   238  			checkFrameSerialization(f)
   239  		})
   240  		Expect(m.DeleteStream(4)).To(Succeed())
   241  	})
   242  
   243  	Context("using high stream limits", func() {
   244  		BeforeEach(func() { maxNumStreams = uint64(protocol.MaxStreamCount) - 2 })
   245  
   246  		It("doesn't send MAX_STREAMS frames if they would overflow 2^60 (the maximum stream count)", func() {
   247  			// open a bunch of streams
   248  			_, err := m.GetOrOpenStream(5)
   249  			Expect(err).ToNot(HaveOccurred())
   250  			// accept all streams
   251  			for i := 0; i < 5; i++ {
   252  				_, err := m.AcceptStream(context.Background())
   253  				Expect(err).ToNot(HaveOccurred())
   254  			}
   255  			mockSender.EXPECT().queueControlFrame(gomock.Any()).Do(func(f wire.Frame) {
   256  				Expect(f.(*wire.MaxStreamsFrame).MaxStreamNum).To(Equal(protocol.MaxStreamCount - 1))
   257  				checkFrameSerialization(f)
   258  			})
   259  			Expect(m.DeleteStream(4)).To(Succeed())
   260  			mockSender.EXPECT().queueControlFrame(gomock.Any()).Do(func(f wire.Frame) {
   261  				Expect(f.(*wire.MaxStreamsFrame).MaxStreamNum).To(Equal(protocol.MaxStreamCount))
   262  				checkFrameSerialization(f)
   263  			})
   264  			Expect(m.DeleteStream(3)).To(Succeed())
   265  			// at this point, we can't increase the stream limit any further, so no more MAX_STREAMS frames will be sent
   266  			Expect(m.DeleteStream(2)).To(Succeed())
   267  			Expect(m.DeleteStream(1)).To(Succeed())
   268  		})
   269  	})
   270  
   271  	Context("randomized tests", func() {
   272  		const num = 1000
   273  
   274  		BeforeEach(func() { maxNumStreams = num })
   275  
   276  		It("opens and accepts streams", func() {
   277  			rand.Seed(GinkgoRandomSeed())
   278  			ids := make([]protocol.StreamNum, num)
   279  			for i := 0; i < num; i++ {
   280  				ids[i] = protocol.StreamNum(i + 1)
   281  			}
   282  			rand.Shuffle(len(ids), func(i, j int) { ids[i], ids[j] = ids[j], ids[i] })
   283  
   284  			const timeout = 5 * time.Second
   285  			done := make(chan struct{}, 2)
   286  			go func() {
   287  				defer GinkgoRecover()
   288  				ctx, cancel := context.WithTimeout(context.Background(), timeout)
   289  				defer cancel()
   290  				for i := 0; i < num; i++ {
   291  					_, err := m.AcceptStream(ctx)
   292  					Expect(err).ToNot(HaveOccurred())
   293  				}
   294  				done <- struct{}{}
   295  			}()
   296  
   297  			go func() {
   298  				defer GinkgoRecover()
   299  				for i := 0; i < num; i++ {
   300  					_, err := m.GetOrOpenStream(ids[i])
   301  					Expect(err).ToNot(HaveOccurred())
   302  				}
   303  				done <- struct{}{}
   304  			}()
   305  
   306  			Eventually(done, timeout*3/2).Should(Receive())
   307  			Eventually(done, timeout*3/2).Should(Receive())
   308  		})
   309  	})
   310  })