github.com/tumi8/quic-go@v0.37.4-tum/streams_map_incoming_test.go (about)

     1  package quic
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  	"time"
     7  
     8  	"golang.org/x/exp/rand"
     9  
    10  	"github.com/tumi8/quic-go/noninternal/protocol"
    11  	"github.com/tumi8/quic-go/noninternal/wire"
    12  
    13  	"github.com/golang/mock/gomock"
    14  	. "github.com/onsi/ginkgo/v2"
    15  	. "github.com/onsi/gomega"
    16  )
    17  
    18  type mockGenericStream struct {
    19  	num protocol.StreamNum
    20  
    21  	closed     bool
    22  	closeErr   error
    23  	sendWindow protocol.ByteCount
    24  }
    25  
    26  func (s *mockGenericStream) closeForShutdown(err error) {
    27  	s.closed = true
    28  	s.closeErr = err
    29  }
    30  
    31  func (s *mockGenericStream) updateSendWindow(limit protocol.ByteCount) {
    32  	s.sendWindow = limit
    33  }
    34  
    35  var _ = Describe("Streams Map (incoming)", func() {
    36  	var (
    37  		m              *incomingStreamsMap[*mockGenericStream]
    38  		newItemCounter int
    39  		mockSender     *MockStreamSender
    40  		maxNumStreams  uint64
    41  	)
    42  	streamType := []protocol.StreamType{protocol.StreamTypeUni, protocol.StreamTypeUni}[rand.Intn(2)]
    43  
    44  	// check that the frame can be serialized and deserialized
    45  	checkFrameSerialization := func(f wire.Frame) {
    46  		b, err := f.Append(nil, protocol.Version1)
    47  		ExpectWithOffset(1, err).ToNot(HaveOccurred())
    48  		_, frame, err := wire.NewFrameParser(false).ParseNext(b, protocol.Encryption1RTT, protocol.Version1)
    49  		ExpectWithOffset(1, err).ToNot(HaveOccurred())
    50  		Expect(f).To(Equal(frame))
    51  	}
    52  
    53  	BeforeEach(func() { maxNumStreams = 5 })
    54  
    55  	JustBeforeEach(func() {
    56  		newItemCounter = 0
    57  		mockSender = NewMockStreamSender(mockCtrl)
    58  		m = newIncomingStreamsMap(
    59  			streamType,
    60  			func(num protocol.StreamNum) *mockGenericStream {
    61  				newItemCounter++
    62  				return &mockGenericStream{num: num}
    63  			},
    64  			maxNumStreams,
    65  			mockSender.queueControlFrame,
    66  		)
    67  	})
    68  
    69  	It("opens all streams up to the id on GetOrOpenStream", func() {
    70  		_, err := m.GetOrOpenStream(4)
    71  		Expect(err).ToNot(HaveOccurred())
    72  		Expect(newItemCounter).To(Equal(4))
    73  	})
    74  
    75  	It("starts opening streams at the right position", func() {
    76  		// like the test above, but with 2 calls to GetOrOpenStream
    77  		_, err := m.GetOrOpenStream(2)
    78  		Expect(err).ToNot(HaveOccurred())
    79  		Expect(newItemCounter).To(Equal(2))
    80  		_, err = m.GetOrOpenStream(5)
    81  		Expect(err).ToNot(HaveOccurred())
    82  		Expect(newItemCounter).To(Equal(5))
    83  	})
    84  
    85  	It("accepts streams in the right order", func() {
    86  		_, err := m.GetOrOpenStream(2) // open streams 1 and 2
    87  		Expect(err).ToNot(HaveOccurred())
    88  		str, err := m.AcceptStream(context.Background())
    89  		Expect(err).ToNot(HaveOccurred())
    90  		Expect(str.num).To(Equal(protocol.StreamNum(1)))
    91  		str, err = m.AcceptStream(context.Background())
    92  		Expect(err).ToNot(HaveOccurred())
    93  		Expect(str.num).To(Equal(protocol.StreamNum(2)))
    94  	})
    95  
    96  	It("allows opening the maximum stream ID", func() {
    97  		str, err := m.GetOrOpenStream(1)
    98  		Expect(err).ToNot(HaveOccurred())
    99  		Expect(str.num).To(Equal(protocol.StreamNum(1)))
   100  	})
   101  
   102  	It("errors when trying to get a stream ID higher than the maximum", func() {
   103  		_, err := m.GetOrOpenStream(6)
   104  		Expect(err).To(HaveOccurred())
   105  		Expect(err.(streamError).TestError()).To(MatchError("peer tried to open stream 6 (current limit: 5)"))
   106  	})
   107  
   108  	It("blocks AcceptStream until a new stream is available", func() {
   109  		strChan := make(chan *mockGenericStream)
   110  		go func() {
   111  			defer GinkgoRecover()
   112  			str, err := m.AcceptStream(context.Background())
   113  			Expect(err).ToNot(HaveOccurred())
   114  			strChan <- str
   115  		}()
   116  		Consistently(strChan).ShouldNot(Receive())
   117  		str, err := m.GetOrOpenStream(1)
   118  		Expect(err).ToNot(HaveOccurred())
   119  		Expect(str.num).To(Equal(protocol.StreamNum(1)))
   120  		var acceptedStr *mockGenericStream
   121  		Eventually(strChan).Should(Receive(&acceptedStr))
   122  		Expect(acceptedStr.num).To(Equal(protocol.StreamNum(1)))
   123  	})
   124  
   125  	It("unblocks AcceptStream when the context is canceled", func() {
   126  		ctx, cancel := context.WithCancel(context.Background())
   127  		done := make(chan struct{})
   128  		go func() {
   129  			defer GinkgoRecover()
   130  			_, err := m.AcceptStream(ctx)
   131  			Expect(err).To(MatchError("context canceled"))
   132  			close(done)
   133  		}()
   134  		Consistently(done).ShouldNot(BeClosed())
   135  		cancel()
   136  		Eventually(done).Should(BeClosed())
   137  	})
   138  
   139  	It("unblocks AcceptStream when it is closed", func() {
   140  		testErr := errors.New("test error")
   141  		done := make(chan struct{})
   142  		go func() {
   143  			defer GinkgoRecover()
   144  			_, err := m.AcceptStream(context.Background())
   145  			Expect(err).To(MatchError(testErr))
   146  			close(done)
   147  		}()
   148  		Consistently(done).ShouldNot(BeClosed())
   149  		m.CloseWithError(testErr)
   150  		Eventually(done).Should(BeClosed())
   151  	})
   152  
   153  	It("errors AcceptStream immediately if it is closed", func() {
   154  		testErr := errors.New("test error")
   155  		m.CloseWithError(testErr)
   156  		_, err := m.AcceptStream(context.Background())
   157  		Expect(err).To(MatchError(testErr))
   158  	})
   159  
   160  	It("closes all streams when CloseWithError is called", func() {
   161  		str1, err := m.GetOrOpenStream(1)
   162  		Expect(err).ToNot(HaveOccurred())
   163  		str2, err := m.GetOrOpenStream(3)
   164  		Expect(err).ToNot(HaveOccurred())
   165  		testErr := errors.New("test err")
   166  		m.CloseWithError(testErr)
   167  		Expect(str1.closed).To(BeTrue())
   168  		Expect(str1.closeErr).To(MatchError(testErr))
   169  		Expect(str2.closed).To(BeTrue())
   170  		Expect(str2.closeErr).To(MatchError(testErr))
   171  	})
   172  
   173  	It("deletes streams", func() {
   174  		mockSender.EXPECT().queueControlFrame(gomock.Any())
   175  		_, err := m.GetOrOpenStream(1)
   176  		Expect(err).ToNot(HaveOccurred())
   177  		str, err := m.AcceptStream(context.Background())
   178  		Expect(err).ToNot(HaveOccurred())
   179  		Expect(str.num).To(Equal(protocol.StreamNum(1)))
   180  		Expect(m.DeleteStream(1)).To(Succeed())
   181  		str, err = m.GetOrOpenStream(1)
   182  		Expect(err).ToNot(HaveOccurred())
   183  		Expect(str).To(BeNil())
   184  	})
   185  
   186  	It("waits until a stream is accepted before actually deleting it", func() {
   187  		_, err := m.GetOrOpenStream(2)
   188  		Expect(err).ToNot(HaveOccurred())
   189  		Expect(m.DeleteStream(2)).To(Succeed())
   190  		str, err := m.AcceptStream(context.Background())
   191  		Expect(err).ToNot(HaveOccurred())
   192  		Expect(str.num).To(Equal(protocol.StreamNum(1)))
   193  		// when accepting this stream, it will get deleted, and a MAX_STREAMS frame is queued
   194  		mockSender.EXPECT().queueControlFrame(gomock.Any())
   195  		str, err = m.AcceptStream(context.Background())
   196  		Expect(err).ToNot(HaveOccurred())
   197  		Expect(str.num).To(Equal(protocol.StreamNum(2)))
   198  	})
   199  
   200  	It("doesn't return a stream queued for deleting from GetOrOpenStream", func() {
   201  		str, err := m.GetOrOpenStream(1)
   202  		Expect(err).ToNot(HaveOccurred())
   203  		Expect(str).ToNot(BeNil())
   204  		Expect(m.DeleteStream(1)).To(Succeed())
   205  		str, err = m.GetOrOpenStream(1)
   206  		Expect(err).ToNot(HaveOccurred())
   207  		Expect(str).To(BeNil())
   208  		// when accepting this stream, it will get deleted, and a MAX_STREAMS frame is queued
   209  		mockSender.EXPECT().queueControlFrame(gomock.Any())
   210  		str, err = m.AcceptStream(context.Background())
   211  		Expect(err).ToNot(HaveOccurred())
   212  		Expect(str).ToNot(BeNil())
   213  	})
   214  
   215  	It("errors when deleting a non-existing stream", func() {
   216  		err := m.DeleteStream(1337)
   217  		Expect(err).To(HaveOccurred())
   218  		Expect(err.(streamError).TestError()).To(MatchError("tried to delete unknown incoming stream 1337"))
   219  	})
   220  
   221  	It("sends MAX_STREAMS frames when streams are deleted", func() {
   222  		// open a bunch of streams
   223  		_, err := m.GetOrOpenStream(5)
   224  		Expect(err).ToNot(HaveOccurred())
   225  		// accept all streams
   226  		for i := 0; i < 5; i++ {
   227  			_, err := m.AcceptStream(context.Background())
   228  			Expect(err).ToNot(HaveOccurred())
   229  		}
   230  		mockSender.EXPECT().queueControlFrame(gomock.Any()).Do(func(f wire.Frame) {
   231  			msf := f.(*wire.MaxStreamsFrame)
   232  			Expect(msf.Type).To(BeEquivalentTo(streamType))
   233  			Expect(msf.MaxStreamNum).To(Equal(protocol.StreamNum(maxNumStreams + 1)))
   234  			checkFrameSerialization(f)
   235  		})
   236  		Expect(m.DeleteStream(3)).To(Succeed())
   237  		mockSender.EXPECT().queueControlFrame(gomock.Any()).Do(func(f wire.Frame) {
   238  			Expect(f.(*wire.MaxStreamsFrame).MaxStreamNum).To(Equal(protocol.StreamNum(maxNumStreams + 2)))
   239  			checkFrameSerialization(f)
   240  		})
   241  		Expect(m.DeleteStream(4)).To(Succeed())
   242  	})
   243  
   244  	Context("using high stream limits", func() {
   245  		BeforeEach(func() { maxNumStreams = uint64(protocol.MaxStreamCount) - 2 })
   246  
   247  		It("doesn't send MAX_STREAMS frames if they would overflow 2^60 (the maximum stream count)", func() {
   248  			// open a bunch of streams
   249  			_, err := m.GetOrOpenStream(5)
   250  			Expect(err).ToNot(HaveOccurred())
   251  			// accept all streams
   252  			for i := 0; i < 5; i++ {
   253  				_, err := m.AcceptStream(context.Background())
   254  				Expect(err).ToNot(HaveOccurred())
   255  			}
   256  			mockSender.EXPECT().queueControlFrame(gomock.Any()).Do(func(f wire.Frame) {
   257  				Expect(f.(*wire.MaxStreamsFrame).MaxStreamNum).To(Equal(protocol.MaxStreamCount - 1))
   258  				checkFrameSerialization(f)
   259  			})
   260  			Expect(m.DeleteStream(4)).To(Succeed())
   261  			mockSender.EXPECT().queueControlFrame(gomock.Any()).Do(func(f wire.Frame) {
   262  				Expect(f.(*wire.MaxStreamsFrame).MaxStreamNum).To(Equal(protocol.MaxStreamCount))
   263  				checkFrameSerialization(f)
   264  			})
   265  			Expect(m.DeleteStream(3)).To(Succeed())
   266  			// at this point, we can't increase the stream limit any further, so no more MAX_STREAMS frames will be sent
   267  			Expect(m.DeleteStream(2)).To(Succeed())
   268  			Expect(m.DeleteStream(1)).To(Succeed())
   269  		})
   270  	})
   271  
   272  	Context("randomized tests", func() {
   273  		const num = 1000
   274  
   275  		BeforeEach(func() { maxNumStreams = num })
   276  
   277  		It("opens and accepts streams", func() {
   278  			rand.Seed(uint64(GinkgoRandomSeed()))
   279  			ids := make([]protocol.StreamNum, num)
   280  			for i := 0; i < num; i++ {
   281  				ids[i] = protocol.StreamNum(i + 1)
   282  			}
   283  			rand.Shuffle(len(ids), func(i, j int) { ids[i], ids[j] = ids[j], ids[i] })
   284  
   285  			const timeout = 5 * time.Second
   286  			done := make(chan struct{}, 2)
   287  			go func() {
   288  				defer GinkgoRecover()
   289  				ctx, cancel := context.WithTimeout(context.Background(), timeout)
   290  				defer cancel()
   291  				for i := 0; i < num; i++ {
   292  					_, err := m.AcceptStream(ctx)
   293  					Expect(err).ToNot(HaveOccurred())
   294  				}
   295  				done <- struct{}{}
   296  			}()
   297  
   298  			go func() {
   299  				defer GinkgoRecover()
   300  				for i := 0; i < num; i++ {
   301  					_, err := m.GetOrOpenStream(ids[i])
   302  					Expect(err).ToNot(HaveOccurred())
   303  				}
   304  				done <- struct{}{}
   305  			}()
   306  
   307  			Eventually(done, timeout*3/2).Should(Receive())
   308  			Eventually(done, timeout*3/2).Should(Receive())
   309  		})
   310  	})
   311  })