github.com/mikelsr/quic-go@v0.36.1-0.20230701132136-1d9415b66898/streams_map_outgoing_test.go (about)

     1  package quic
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  	"fmt"
     7  	"math/rand"
     8  	"sort"
     9  	"sync"
    10  	"time"
    11  
    12  	"github.com/mikelsr/quic-go/internal/protocol"
    13  	"github.com/mikelsr/quic-go/internal/wire"
    14  
    15  	"github.com/golang/mock/gomock"
    16  	. "github.com/onsi/ginkgo/v2"
    17  	. "github.com/onsi/gomega"
    18  )
    19  
    20  var _ = Describe("Streams Map (outgoing)", func() {
    21  	var (
    22  		m          *outgoingStreamsMap[*mockGenericStream]
    23  		newStr     func(num protocol.StreamNum) *mockGenericStream
    24  		mockSender *MockStreamSender
    25  	)
    26  
    27  	const streamType = 42
    28  
    29  	// waitForEnqueued waits until there are n go routines waiting on OpenStreamSync()
    30  	waitForEnqueued := func(n int) {
    31  		Eventually(func() int {
    32  			m.mutex.Lock()
    33  			defer m.mutex.Unlock()
    34  			return len(m.openQueue)
    35  		}, 50*time.Millisecond, 100*time.Microsecond).Should(Equal(n))
    36  	}
    37  
    38  	BeforeEach(func() {
    39  		newStr = func(num protocol.StreamNum) *mockGenericStream {
    40  			return &mockGenericStream{num: num}
    41  		}
    42  		mockSender = NewMockStreamSender(mockCtrl)
    43  		m = newOutgoingStreamsMap[*mockGenericStream](streamType, newStr, mockSender.queueControlFrame)
    44  	})
    45  
    46  	Context("no stream ID limit", func() {
    47  		BeforeEach(func() {
    48  			m.SetMaxStream(0xffffffff)
    49  		})
    50  
    51  		It("opens streams", func() {
    52  			str, err := m.OpenStream()
    53  			Expect(err).ToNot(HaveOccurred())
    54  			Expect(str.num).To(Equal(protocol.StreamNum(1)))
    55  			str, err = m.OpenStream()
    56  			Expect(err).ToNot(HaveOccurred())
    57  			Expect(str.num).To(Equal(protocol.StreamNum(2)))
    58  		})
    59  
    60  		It("doesn't open streams after it has been closed", func() {
    61  			testErr := errors.New("close")
    62  			m.CloseWithError(testErr)
    63  			_, err := m.OpenStream()
    64  			Expect(err).To(MatchError(testErr))
    65  		})
    66  
    67  		It("gets streams", func() {
    68  			_, err := m.OpenStream()
    69  			Expect(err).ToNot(HaveOccurred())
    70  			str, err := m.GetStream(1)
    71  			Expect(err).ToNot(HaveOccurred())
    72  			Expect(str.num).To(Equal(protocol.StreamNum(1)))
    73  		})
    74  
    75  		It("errors when trying to get a stream that has not yet been opened", func() {
    76  			_, err := m.GetStream(1)
    77  			Expect(err).To(HaveOccurred())
    78  			Expect(err.(streamError).TestError()).To(MatchError("peer attempted to open stream 1"))
    79  		})
    80  
    81  		It("deletes streams", func() {
    82  			_, err := m.OpenStream()
    83  			Expect(err).ToNot(HaveOccurred())
    84  			Expect(m.DeleteStream(1)).To(Succeed())
    85  			Expect(err).ToNot(HaveOccurred())
    86  			str, err := m.GetStream(1)
    87  			Expect(err).ToNot(HaveOccurred())
    88  			Expect(str).To(BeNil())
    89  		})
    90  
    91  		It("errors when deleting a non-existing stream", func() {
    92  			err := m.DeleteStream(1337)
    93  			Expect(err).To(HaveOccurred())
    94  			Expect(err.(streamError).TestError()).To(MatchError("tried to delete unknown outgoing stream 1337"))
    95  		})
    96  
    97  		It("errors when deleting a stream twice", func() {
    98  			_, err := m.OpenStream() // opens firstNewStream
    99  			Expect(err).ToNot(HaveOccurred())
   100  			Expect(m.DeleteStream(1)).To(Succeed())
   101  			err = m.DeleteStream(1)
   102  			Expect(err).To(HaveOccurred())
   103  			Expect(err.(streamError).TestError()).To(MatchError("tried to delete unknown outgoing stream 1"))
   104  		})
   105  
   106  		It("closes all streams when CloseWithError is called", func() {
   107  			str1, err := m.OpenStream()
   108  			Expect(err).ToNot(HaveOccurred())
   109  			str2, err := m.OpenStream()
   110  			Expect(err).ToNot(HaveOccurred())
   111  			testErr := errors.New("test err")
   112  			m.CloseWithError(testErr)
   113  			Expect(str1.closed).To(BeTrue())
   114  			Expect(str1.closeErr).To(MatchError(testErr))
   115  			Expect(str2.closed).To(BeTrue())
   116  			Expect(str2.closeErr).To(MatchError(testErr))
   117  		})
   118  
   119  		It("updates the send window", func() {
   120  			str1, err := m.OpenStream()
   121  			Expect(err).ToNot(HaveOccurred())
   122  			str2, err := m.OpenStream()
   123  			Expect(err).ToNot(HaveOccurred())
   124  			m.UpdateSendWindow(1337)
   125  			Expect(str1.sendWindow).To(BeEquivalentTo(1337))
   126  			Expect(str2.sendWindow).To(BeEquivalentTo(1337))
   127  		})
   128  	})
   129  
   130  	Context("with stream ID limits", func() {
   131  		It("errors when no stream can be opened immediately", func() {
   132  			mockSender.EXPECT().queueControlFrame(gomock.Any())
   133  			_, err := m.OpenStream()
   134  			expectTooManyStreamsError(err)
   135  		})
   136  
   137  		It("returns immediately when called with a canceled context", func() {
   138  			ctx, cancel := context.WithCancel(context.Background())
   139  			cancel()
   140  			_, err := m.OpenStreamSync(ctx)
   141  			Expect(err).To(MatchError("context canceled"))
   142  		})
   143  
   144  		It("blocks until a stream can be opened synchronously", func() {
   145  			mockSender.EXPECT().queueControlFrame(gomock.Any())
   146  			done := make(chan struct{})
   147  			go func() {
   148  				defer GinkgoRecover()
   149  				str, err := m.OpenStreamSync(context.Background())
   150  				Expect(err).ToNot(HaveOccurred())
   151  				Expect(str.num).To(Equal(protocol.StreamNum(1)))
   152  				close(done)
   153  			}()
   154  			waitForEnqueued(1)
   155  
   156  			m.SetMaxStream(1)
   157  			Eventually(done).Should(BeClosed())
   158  		})
   159  
   160  		It("unblocks when the context is canceled", func() {
   161  			mockSender.EXPECT().queueControlFrame(gomock.Any())
   162  			ctx, cancel := context.WithCancel(context.Background())
   163  			done := make(chan struct{})
   164  			go func() {
   165  				defer GinkgoRecover()
   166  				_, err := m.OpenStreamSync(ctx)
   167  				Expect(err).To(MatchError("context canceled"))
   168  				close(done)
   169  			}()
   170  			waitForEnqueued(1)
   171  
   172  			cancel()
   173  			Eventually(done).Should(BeClosed())
   174  
   175  			// make sure that the next stream opened is stream 1
   176  			m.SetMaxStream(1000)
   177  			str, err := m.OpenStream()
   178  			Expect(err).ToNot(HaveOccurred())
   179  			Expect(str.num).To(Equal(protocol.StreamNum(1)))
   180  		})
   181  
   182  		It("opens streams in the right order", func() {
   183  			mockSender.EXPECT().queueControlFrame(gomock.Any()).AnyTimes()
   184  			done1 := make(chan struct{})
   185  			go func() {
   186  				defer GinkgoRecover()
   187  				str, err := m.OpenStreamSync(context.Background())
   188  				Expect(err).ToNot(HaveOccurred())
   189  				Expect(str.num).To(Equal(protocol.StreamNum(1)))
   190  				close(done1)
   191  			}()
   192  			waitForEnqueued(1)
   193  
   194  			done2 := make(chan struct{})
   195  			go func() {
   196  				defer GinkgoRecover()
   197  				str, err := m.OpenStreamSync(context.Background())
   198  				Expect(err).ToNot(HaveOccurred())
   199  				Expect(str.num).To(Equal(protocol.StreamNum(2)))
   200  				close(done2)
   201  			}()
   202  			waitForEnqueued(2)
   203  
   204  			m.SetMaxStream(1)
   205  			Eventually(done1).Should(BeClosed())
   206  			Consistently(done2).ShouldNot(BeClosed())
   207  			m.SetMaxStream(2)
   208  			Eventually(done2).Should(BeClosed())
   209  		})
   210  
   211  		It("opens streams in the right order, when one of the contexts is canceled", func() {
   212  			mockSender.EXPECT().queueControlFrame(gomock.Any()).AnyTimes()
   213  			done1 := make(chan struct{})
   214  			go func() {
   215  				defer GinkgoRecover()
   216  				str, err := m.OpenStreamSync(context.Background())
   217  				Expect(err).ToNot(HaveOccurred())
   218  				Expect(str.num).To(Equal(protocol.StreamNum(1)))
   219  				close(done1)
   220  			}()
   221  			waitForEnqueued(1)
   222  
   223  			done2 := make(chan struct{})
   224  			ctx, cancel := context.WithCancel(context.Background())
   225  			go func() {
   226  				defer GinkgoRecover()
   227  				_, err := m.OpenStreamSync(ctx)
   228  				Expect(err).To(MatchError(context.Canceled))
   229  				close(done2)
   230  			}()
   231  			waitForEnqueued(2)
   232  
   233  			done3 := make(chan struct{})
   234  			go func() {
   235  				defer GinkgoRecover()
   236  				str, err := m.OpenStreamSync(context.Background())
   237  				Expect(err).ToNot(HaveOccurred())
   238  				Expect(str.num).To(Equal(protocol.StreamNum(2)))
   239  				close(done3)
   240  			}()
   241  			waitForEnqueued(3)
   242  
   243  			cancel()
   244  			Eventually(done2).Should(BeClosed())
   245  			m.SetMaxStream(1000)
   246  			Eventually(done1).Should(BeClosed())
   247  			Eventually(done3).Should(BeClosed())
   248  		})
   249  
   250  		It("unblocks multiple OpenStreamSync calls at the same time", func() {
   251  			mockSender.EXPECT().queueControlFrame(gomock.Any()).AnyTimes()
   252  			done := make(chan struct{})
   253  			go func() {
   254  				defer GinkgoRecover()
   255  				_, err := m.OpenStreamSync(context.Background())
   256  				Expect(err).ToNot(HaveOccurred())
   257  				done <- struct{}{}
   258  			}()
   259  			go func() {
   260  				defer GinkgoRecover()
   261  				_, err := m.OpenStreamSync(context.Background())
   262  				Expect(err).ToNot(HaveOccurred())
   263  				done <- struct{}{}
   264  			}()
   265  			waitForEnqueued(2)
   266  			go func() {
   267  				defer GinkgoRecover()
   268  				_, err := m.OpenStreamSync(context.Background())
   269  				Expect(err).To(MatchError("test done"))
   270  				done <- struct{}{}
   271  			}()
   272  			waitForEnqueued(3)
   273  
   274  			m.SetMaxStream(2)
   275  			Eventually(done).Should(Receive())
   276  			Eventually(done).Should(Receive())
   277  			Consistently(done).ShouldNot(Receive())
   278  
   279  			m.CloseWithError(errors.New("test done"))
   280  			Eventually(done).Should(Receive())
   281  		})
   282  
   283  		It("returns an error for OpenStream while an OpenStreamSync call is blocking", func() {
   284  			mockSender.EXPECT().queueControlFrame(gomock.Any()).MaxTimes(2)
   285  			openedSync := make(chan struct{})
   286  			go func() {
   287  				defer GinkgoRecover()
   288  				str, err := m.OpenStreamSync(context.Background())
   289  				Expect(err).ToNot(HaveOccurred())
   290  				Expect(str.num).To(Equal(protocol.StreamNum(1)))
   291  				close(openedSync)
   292  			}()
   293  			waitForEnqueued(1)
   294  
   295  			start := make(chan struct{})
   296  			openend := make(chan struct{})
   297  			go func() {
   298  				defer GinkgoRecover()
   299  				var hasStarted bool
   300  				for {
   301  					str, err := m.OpenStream()
   302  					if err == nil {
   303  						Expect(str.num).To(Equal(protocol.StreamNum(2)))
   304  						close(openend)
   305  						return
   306  					}
   307  					expectTooManyStreamsError(err)
   308  					if !hasStarted {
   309  						close(start)
   310  						hasStarted = true
   311  					}
   312  				}
   313  			}()
   314  
   315  			Eventually(start).Should(BeClosed())
   316  			m.SetMaxStream(1)
   317  			Eventually(openedSync).Should(BeClosed())
   318  			Consistently(openend).ShouldNot(BeClosed())
   319  			m.SetMaxStream(2)
   320  			Eventually(openend).Should(BeClosed())
   321  		})
   322  
   323  		It("stops opening synchronously when it is closed", func() {
   324  			mockSender.EXPECT().queueControlFrame(gomock.Any())
   325  			testErr := errors.New("test error")
   326  			done := make(chan struct{})
   327  			go func() {
   328  				defer GinkgoRecover()
   329  				_, err := m.OpenStreamSync(context.Background())
   330  				Expect(err).To(MatchError(testErr))
   331  				close(done)
   332  			}()
   333  
   334  			Consistently(done).ShouldNot(BeClosed())
   335  			m.CloseWithError(testErr)
   336  			Eventually(done).Should(BeClosed())
   337  		})
   338  
   339  		It("doesn't reduce the stream limit", func() {
   340  			m.SetMaxStream(2)
   341  			m.SetMaxStream(1)
   342  			_, err := m.OpenStream()
   343  			Expect(err).ToNot(HaveOccurred())
   344  			str, err := m.OpenStream()
   345  			Expect(err).ToNot(HaveOccurred())
   346  			Expect(str.num).To(Equal(protocol.StreamNum(2)))
   347  		})
   348  
   349  		It("queues a STREAMS_BLOCKED frame if no stream can be opened", func() {
   350  			m.SetMaxStream(6)
   351  			// open the 6 allowed streams
   352  			for i := 0; i < 6; i++ {
   353  				_, err := m.OpenStream()
   354  				Expect(err).ToNot(HaveOccurred())
   355  			}
   356  
   357  			mockSender.EXPECT().queueControlFrame(gomock.Any()).Do(func(f wire.Frame) {
   358  				bf := f.(*wire.StreamsBlockedFrame)
   359  				Expect(bf.Type).To(BeEquivalentTo(streamType))
   360  				Expect(bf.StreamLimit).To(BeEquivalentTo(6))
   361  			})
   362  			_, err := m.OpenStream()
   363  			Expect(err).To(HaveOccurred())
   364  			Expect(err.Error()).To(Equal(errTooManyOpenStreams.Error()))
   365  		})
   366  
   367  		It("only sends one STREAMS_BLOCKED frame for one stream ID", func() {
   368  			m.SetMaxStream(1)
   369  			mockSender.EXPECT().queueControlFrame(gomock.Any()).Do(func(f wire.Frame) {
   370  				Expect(f.(*wire.StreamsBlockedFrame).StreamLimit).To(BeEquivalentTo(1))
   371  			})
   372  			_, err := m.OpenStream()
   373  			Expect(err).ToNot(HaveOccurred())
   374  			// try to open a stream twice, but expect only one STREAMS_BLOCKED to be sent
   375  			_, err = m.OpenStream()
   376  			expectTooManyStreamsError(err)
   377  			_, err = m.OpenStream()
   378  			expectTooManyStreamsError(err)
   379  		})
   380  
   381  		It("queues a STREAMS_BLOCKED frame when there more streams waiting for OpenStreamSync than MAX_STREAMS allows", func() {
   382  			mockSender.EXPECT().queueControlFrame(gomock.Any()).Do(func(f wire.Frame) {
   383  				Expect(f.(*wire.StreamsBlockedFrame).StreamLimit).To(BeEquivalentTo(0))
   384  			})
   385  			done := make(chan struct{}, 2)
   386  			go func() {
   387  				defer GinkgoRecover()
   388  				_, err := m.OpenStreamSync(context.Background())
   389  				Expect(err).ToNot(HaveOccurred())
   390  				done <- struct{}{}
   391  			}()
   392  			go func() {
   393  				defer GinkgoRecover()
   394  				_, err := m.OpenStreamSync(context.Background())
   395  				Expect(err).ToNot(HaveOccurred())
   396  				done <- struct{}{}
   397  			}()
   398  			waitForEnqueued(2)
   399  
   400  			mockSender.EXPECT().queueControlFrame(gomock.Any()).Do(func(f wire.Frame) {
   401  				Expect(f.(*wire.StreamsBlockedFrame).StreamLimit).To(BeEquivalentTo(1))
   402  			})
   403  			m.SetMaxStream(1)
   404  			Eventually(done).Should(Receive())
   405  			Consistently(done).ShouldNot(Receive())
   406  			m.SetMaxStream(2)
   407  			Eventually(done).Should(Receive())
   408  		})
   409  	})
   410  
   411  	Context("randomized tests", func() {
   412  		It("opens streams", func() {
   413  			rand.Seed(GinkgoRandomSeed())
   414  			const n = 100
   415  			fmt.Fprintf(GinkgoWriter, "Opening %d streams concurrently.\n", n)
   416  
   417  			var blockedAt []protocol.StreamNum
   418  			mockSender.EXPECT().queueControlFrame(gomock.Any()).Do(func(f wire.Frame) {
   419  				blockedAt = append(blockedAt, f.(*wire.StreamsBlockedFrame).StreamLimit)
   420  			}).AnyTimes()
   421  			done := make(map[int]chan struct{})
   422  			for i := 1; i <= n; i++ {
   423  				c := make(chan struct{})
   424  				done[i] = c
   425  
   426  				go func(doneChan chan struct{}, id protocol.StreamNum) {
   427  					defer GinkgoRecover()
   428  					defer close(doneChan)
   429  					str, err := m.OpenStreamSync(context.Background())
   430  					Expect(err).ToNot(HaveOccurred())
   431  					Expect(str.num).To(Equal(id))
   432  				}(c, protocol.StreamNum(i))
   433  				waitForEnqueued(i)
   434  			}
   435  
   436  			var limit int
   437  			limits := []protocol.StreamNum{0}
   438  			for limit < n {
   439  				limit += rand.Intn(n/5) + 1
   440  				if limit <= n {
   441  					limits = append(limits, protocol.StreamNum(limit))
   442  				}
   443  				fmt.Fprintf(GinkgoWriter, "Setting stream limit to %d.\n", limit)
   444  				m.SetMaxStream(protocol.StreamNum(limit))
   445  				for i := 1; i <= n; i++ {
   446  					if i <= limit {
   447  						Eventually(done[i]).Should(BeClosed())
   448  					} else {
   449  						Expect(done[i]).ToNot(BeClosed())
   450  					}
   451  				}
   452  				str, err := m.OpenStream()
   453  				if limit <= n {
   454  					Expect(err).To(HaveOccurred())
   455  					Expect(err.Error()).To(Equal(errTooManyOpenStreams.Error()))
   456  				} else {
   457  					Expect(str.num).To(Equal(protocol.StreamNum(n + 1)))
   458  				}
   459  			}
   460  			Expect(blockedAt).To(Equal(limits))
   461  		})
   462  
   463  		It("opens streams, when some of them are getting canceled", func() {
   464  			rand.Seed(GinkgoRandomSeed())
   465  			const n = 100
   466  			fmt.Fprintf(GinkgoWriter, "Opening %d streams concurrently.\n", n)
   467  
   468  			var blockedAt []protocol.StreamNum
   469  			mockSender.EXPECT().queueControlFrame(gomock.Any()).Do(func(f wire.Frame) {
   470  				blockedAt = append(blockedAt, f.(*wire.StreamsBlockedFrame).StreamLimit)
   471  			}).AnyTimes()
   472  
   473  			ctx, cancel := context.WithCancel(context.Background())
   474  			streamsToCancel := make(map[protocol.StreamNum]struct{}) // used as a set
   475  			for i := 0; i < 10; i++ {
   476  				id := protocol.StreamNum(rand.Intn(n) + 1)
   477  				fmt.Fprintf(GinkgoWriter, "Canceling stream %d.\n", id)
   478  				streamsToCancel[id] = struct{}{}
   479  			}
   480  
   481  			streamWillBeCanceled := func(id protocol.StreamNum) bool {
   482  				_, ok := streamsToCancel[id]
   483  				return ok
   484  			}
   485  
   486  			var streamIDs []int
   487  			var mutex sync.Mutex
   488  			done := make(map[int]chan struct{})
   489  			for i := 1; i <= n; i++ {
   490  				c := make(chan struct{})
   491  				done[i] = c
   492  
   493  				go func(doneChan chan struct{}, id protocol.StreamNum) {
   494  					defer GinkgoRecover()
   495  					defer close(doneChan)
   496  					cont := context.Background()
   497  					if streamWillBeCanceled(id) {
   498  						cont = ctx
   499  					}
   500  					str, err := m.OpenStreamSync(cont)
   501  					if streamWillBeCanceled(id) {
   502  						Expect(err).To(MatchError(context.Canceled))
   503  						return
   504  					}
   505  					Expect(err).ToNot(HaveOccurred())
   506  					mutex.Lock()
   507  					streamIDs = append(streamIDs, int(str.num))
   508  					mutex.Unlock()
   509  				}(c, protocol.StreamNum(i))
   510  				waitForEnqueued(i)
   511  			}
   512  
   513  			cancel()
   514  			for id := range streamsToCancel {
   515  				Eventually(done[int(id)]).Should(BeClosed())
   516  			}
   517  			var limit int
   518  			numStreams := n - len(streamsToCancel)
   519  			var limits []protocol.StreamNum
   520  			for limit < numStreams {
   521  				limits = append(limits, protocol.StreamNum(limit))
   522  				limit += rand.Intn(n/5) + 1
   523  				fmt.Fprintf(GinkgoWriter, "Setting stream limit to %d.\n", limit)
   524  				m.SetMaxStream(protocol.StreamNum(limit))
   525  				l := limit
   526  				if l > numStreams {
   527  					l = numStreams
   528  				}
   529  				Eventually(func() int {
   530  					mutex.Lock()
   531  					defer mutex.Unlock()
   532  					return len(streamIDs)
   533  				}).Should(Equal(l))
   534  				// check that all stream IDs were used
   535  				Expect(streamIDs).To(HaveLen(l))
   536  				sort.Ints(streamIDs)
   537  				for i := 0; i < l; i++ {
   538  					Expect(streamIDs[i]).To(Equal(i + 1))
   539  				}
   540  			}
   541  			Expect(blockedAt).To(Equal(limits))
   542  		})
   543  	})
   544  })