github.com/mikelsr/quic-go@v0.36.1-0.20230701132136-1d9415b66898/streams_map_test.go (about)

     1  package quic
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  	"fmt"
     7  	"net"
     8  
     9  	"github.com/golang/mock/gomock"
    10  
    11  	"github.com/mikelsr/quic-go/internal/flowcontrol"
    12  	"github.com/mikelsr/quic-go/internal/mocks"
    13  	"github.com/mikelsr/quic-go/internal/protocol"
    14  	"github.com/mikelsr/quic-go/internal/qerr"
    15  	"github.com/mikelsr/quic-go/internal/wire"
    16  
    17  	. "github.com/onsi/ginkgo/v2"
    18  	. "github.com/onsi/gomega"
    19  )
    20  
    21  func (e streamError) TestError() error {
    22  	nums := make([]interface{}, len(e.nums))
    23  	for i, num := range e.nums {
    24  		nums[i] = num
    25  	}
    26  	return fmt.Errorf(e.message, nums...)
    27  }
    28  
    29  type streamMapping struct {
    30  	firstIncomingBidiStream protocol.StreamID
    31  	firstIncomingUniStream  protocol.StreamID
    32  	firstOutgoingBidiStream protocol.StreamID
    33  	firstOutgoingUniStream  protocol.StreamID
    34  }
    35  
    36  func expectTooManyStreamsError(err error) {
    37  	ExpectWithOffset(1, err).To(HaveOccurred())
    38  	ExpectWithOffset(1, err.Error()).To(Equal(errTooManyOpenStreams.Error()))
    39  	nerr, ok := err.(net.Error)
    40  	ExpectWithOffset(1, ok).To(BeTrue())
    41  	ExpectWithOffset(1, nerr.Timeout()).To(BeFalse())
    42  }
    43  
    44  var _ = Describe("Streams Map", func() {
    45  	newFlowController := func(protocol.StreamID) flowcontrol.StreamFlowController {
    46  		return mocks.NewMockStreamFlowController(mockCtrl)
    47  	}
    48  
    49  	serverStreamMapping := streamMapping{
    50  		firstIncomingBidiStream: 0,
    51  		firstOutgoingBidiStream: 1,
    52  		firstIncomingUniStream:  2,
    53  		firstOutgoingUniStream:  3,
    54  	}
    55  	clientStreamMapping := streamMapping{
    56  		firstIncomingBidiStream: 1,
    57  		firstOutgoingBidiStream: 0,
    58  		firstIncomingUniStream:  3,
    59  		firstOutgoingUniStream:  2,
    60  	}
    61  
    62  	for _, p := range []protocol.Perspective{protocol.PerspectiveServer, protocol.PerspectiveClient} {
    63  		perspective := p
    64  		var ids streamMapping
    65  		if perspective == protocol.PerspectiveClient {
    66  			ids = clientStreamMapping
    67  		} else {
    68  			ids = serverStreamMapping
    69  		}
    70  
    71  		Context(perspective.String(), func() {
    72  			var (
    73  				m          *streamsMap
    74  				mockSender *MockStreamSender
    75  			)
    76  
    77  			const (
    78  				MaxBidiStreamNum = 111
    79  				MaxUniStreamNum  = 222
    80  			)
    81  
    82  			allowUnlimitedStreams := func() {
    83  				m.UpdateLimits(&wire.TransportParameters{
    84  					MaxBidiStreamNum: protocol.MaxStreamCount,
    85  					MaxUniStreamNum:  protocol.MaxStreamCount,
    86  				})
    87  			}
    88  
    89  			BeforeEach(func() {
    90  				mockSender = NewMockStreamSender(mockCtrl)
    91  				m = newStreamsMap(mockSender, newFlowController, MaxBidiStreamNum, MaxUniStreamNum, perspective).(*streamsMap)
    92  			})
    93  
    94  			Context("opening", func() {
    95  				It("opens bidirectional streams", func() {
    96  					allowUnlimitedStreams()
    97  					str, err := m.OpenStream()
    98  					Expect(err).ToNot(HaveOccurred())
    99  					Expect(str).To(BeAssignableToTypeOf(&stream{}))
   100  					Expect(str.StreamID()).To(Equal(ids.firstOutgoingBidiStream))
   101  					str, err = m.OpenStream()
   102  					Expect(err).ToNot(HaveOccurred())
   103  					Expect(str).To(BeAssignableToTypeOf(&stream{}))
   104  					Expect(str.StreamID()).To(Equal(ids.firstOutgoingBidiStream + 4))
   105  				})
   106  
   107  				It("opens unidirectional streams", func() {
   108  					allowUnlimitedStreams()
   109  					str, err := m.OpenUniStream()
   110  					Expect(err).ToNot(HaveOccurred())
   111  					Expect(str).To(BeAssignableToTypeOf(&sendStream{}))
   112  					Expect(str.StreamID()).To(Equal(ids.firstOutgoingUniStream))
   113  					str, err = m.OpenUniStream()
   114  					Expect(err).ToNot(HaveOccurred())
   115  					Expect(str).To(BeAssignableToTypeOf(&sendStream{}))
   116  					Expect(str.StreamID()).To(Equal(ids.firstOutgoingUniStream + 4))
   117  				})
   118  			})
   119  
   120  			Context("accepting", func() {
   121  				It("accepts bidirectional streams", func() {
   122  					_, err := m.GetOrOpenReceiveStream(ids.firstIncomingBidiStream)
   123  					Expect(err).ToNot(HaveOccurred())
   124  					str, err := m.AcceptStream(context.Background())
   125  					Expect(err).ToNot(HaveOccurred())
   126  					Expect(str).To(BeAssignableToTypeOf(&stream{}))
   127  					Expect(str.StreamID()).To(Equal(ids.firstIncomingBidiStream))
   128  				})
   129  
   130  				It("accepts unidirectional streams", func() {
   131  					_, err := m.GetOrOpenReceiveStream(ids.firstIncomingUniStream)
   132  					Expect(err).ToNot(HaveOccurred())
   133  					str, err := m.AcceptUniStream(context.Background())
   134  					Expect(err).ToNot(HaveOccurred())
   135  					Expect(str).To(BeAssignableToTypeOf(&receiveStream{}))
   136  					Expect(str.StreamID()).To(Equal(ids.firstIncomingUniStream))
   137  				})
   138  			})
   139  
   140  			Context("deleting", func() {
   141  				BeforeEach(func() {
   142  					mockSender.EXPECT().queueControlFrame(gomock.Any()).AnyTimes()
   143  					allowUnlimitedStreams()
   144  				})
   145  
   146  				It("deletes outgoing bidirectional streams", func() {
   147  					id := ids.firstOutgoingBidiStream
   148  					str, err := m.OpenStream()
   149  					Expect(err).ToNot(HaveOccurred())
   150  					Expect(str.StreamID()).To(Equal(id))
   151  					Expect(m.DeleteStream(id)).To(Succeed())
   152  					dstr, err := m.GetOrOpenSendStream(id)
   153  					Expect(err).ToNot(HaveOccurred())
   154  					Expect(dstr).To(BeNil())
   155  				})
   156  
   157  				It("deletes incoming bidirectional streams", func() {
   158  					id := ids.firstIncomingBidiStream
   159  					str, err := m.GetOrOpenReceiveStream(id)
   160  					Expect(err).ToNot(HaveOccurred())
   161  					Expect(str.StreamID()).To(Equal(id))
   162  					Expect(m.DeleteStream(id)).To(Succeed())
   163  					dstr, err := m.GetOrOpenReceiveStream(id)
   164  					Expect(err).ToNot(HaveOccurred())
   165  					Expect(dstr).To(BeNil())
   166  				})
   167  
   168  				It("accepts bidirectional streams after they have been deleted", func() {
   169  					id := ids.firstIncomingBidiStream
   170  					_, err := m.GetOrOpenReceiveStream(id)
   171  					Expect(err).ToNot(HaveOccurred())
   172  					Expect(m.DeleteStream(id)).To(Succeed())
   173  					str, err := m.AcceptStream(context.Background())
   174  					Expect(err).ToNot(HaveOccurred())
   175  					Expect(str).ToNot(BeNil())
   176  					Expect(str.StreamID()).To(Equal(id))
   177  				})
   178  
   179  				It("deletes outgoing unidirectional streams", func() {
   180  					id := ids.firstOutgoingUniStream
   181  					str, err := m.OpenUniStream()
   182  					Expect(err).ToNot(HaveOccurred())
   183  					Expect(str.StreamID()).To(Equal(id))
   184  					Expect(m.DeleteStream(id)).To(Succeed())
   185  					dstr, err := m.GetOrOpenSendStream(id)
   186  					Expect(err).ToNot(HaveOccurred())
   187  					Expect(dstr).To(BeNil())
   188  				})
   189  
   190  				It("deletes incoming unidirectional streams", func() {
   191  					id := ids.firstIncomingUniStream
   192  					str, err := m.GetOrOpenReceiveStream(id)
   193  					Expect(err).ToNot(HaveOccurred())
   194  					Expect(str.StreamID()).To(Equal(id))
   195  					Expect(m.DeleteStream(id)).To(Succeed())
   196  					dstr, err := m.GetOrOpenReceiveStream(id)
   197  					Expect(err).ToNot(HaveOccurred())
   198  					Expect(dstr).To(BeNil())
   199  				})
   200  
   201  				It("accepts unirectional streams after they have been deleted", func() {
   202  					id := ids.firstIncomingUniStream
   203  					_, err := m.GetOrOpenReceiveStream(id)
   204  					Expect(err).ToNot(HaveOccurred())
   205  					Expect(m.DeleteStream(id)).To(Succeed())
   206  					str, err := m.AcceptUniStream(context.Background())
   207  					Expect(err).ToNot(HaveOccurred())
   208  					Expect(str).ToNot(BeNil())
   209  					Expect(str.StreamID()).To(Equal(id))
   210  				})
   211  
   212  				It("errors when deleting unknown incoming unidirectional streams", func() {
   213  					id := ids.firstIncomingUniStream + 4
   214  					Expect(m.DeleteStream(id)).To(MatchError(fmt.Sprintf("tried to delete unknown incoming stream %d", id)))
   215  				})
   216  
   217  				It("errors when deleting unknown outgoing unidirectional streams", func() {
   218  					id := ids.firstOutgoingUniStream + 4
   219  					Expect(m.DeleteStream(id)).To(MatchError(fmt.Sprintf("tried to delete unknown outgoing stream %d", id)))
   220  				})
   221  
   222  				It("errors when deleting unknown incoming bidirectional streams", func() {
   223  					id := ids.firstIncomingBidiStream + 4
   224  					Expect(m.DeleteStream(id)).To(MatchError(fmt.Sprintf("tried to delete unknown incoming stream %d", id)))
   225  				})
   226  
   227  				It("errors when deleting unknown outgoing bidirectional streams", func() {
   228  					id := ids.firstOutgoingBidiStream + 4
   229  					Expect(m.DeleteStream(id)).To(MatchError(fmt.Sprintf("tried to delete unknown outgoing stream %d", id)))
   230  				})
   231  			})
   232  
   233  			Context("getting streams", func() {
   234  				BeforeEach(func() {
   235  					allowUnlimitedStreams()
   236  				})
   237  
   238  				Context("send streams", func() {
   239  					It("gets an outgoing bidirectional stream", func() {
   240  						// need to open the stream ourselves first
   241  						// the peer is not allowed to create a stream initiated by us
   242  						_, err := m.OpenStream()
   243  						Expect(err).ToNot(HaveOccurred())
   244  						str, err := m.GetOrOpenSendStream(ids.firstOutgoingBidiStream)
   245  						Expect(err).ToNot(HaveOccurred())
   246  						Expect(str.StreamID()).To(Equal(ids.firstOutgoingBidiStream))
   247  					})
   248  
   249  					It("errors when the peer tries to open a higher outgoing bidirectional stream", func() {
   250  						id := ids.firstOutgoingBidiStream + 5*4
   251  						_, err := m.GetOrOpenSendStream(id)
   252  						Expect(err).To(MatchError(&qerr.TransportError{
   253  							ErrorCode:    qerr.StreamStateError,
   254  							ErrorMessage: fmt.Sprintf("peer attempted to open stream %d", id),
   255  						}))
   256  					})
   257  
   258  					It("gets an outgoing unidirectional stream", func() {
   259  						// need to open the stream ourselves first
   260  						// the peer is not allowed to create a stream initiated by us
   261  						_, err := m.OpenUniStream()
   262  						Expect(err).ToNot(HaveOccurred())
   263  						str, err := m.GetOrOpenSendStream(ids.firstOutgoingUniStream)
   264  						Expect(err).ToNot(HaveOccurred())
   265  						Expect(str.StreamID()).To(Equal(ids.firstOutgoingUniStream))
   266  					})
   267  
   268  					It("errors when the peer tries to open a higher outgoing bidirectional stream", func() {
   269  						id := ids.firstOutgoingUniStream + 5*4
   270  						_, err := m.GetOrOpenSendStream(id)
   271  						Expect(err).To(MatchError(&qerr.TransportError{
   272  							ErrorCode:    qerr.StreamStateError,
   273  							ErrorMessage: fmt.Sprintf("peer attempted to open stream %d", id),
   274  						}))
   275  					})
   276  
   277  					It("gets an incoming bidirectional stream", func() {
   278  						id := ids.firstIncomingBidiStream + 4*7
   279  						str, err := m.GetOrOpenSendStream(id)
   280  						Expect(err).ToNot(HaveOccurred())
   281  						Expect(str.StreamID()).To(Equal(id))
   282  					})
   283  
   284  					It("errors when trying to get an incoming unidirectional stream", func() {
   285  						id := ids.firstIncomingUniStream
   286  						_, err := m.GetOrOpenSendStream(id)
   287  						Expect(err).To(MatchError(&qerr.TransportError{
   288  							ErrorCode:    qerr.StreamStateError,
   289  							ErrorMessage: fmt.Sprintf("peer attempted to open send stream %d", id),
   290  						}))
   291  					})
   292  				})
   293  
   294  				Context("receive streams", func() {
   295  					It("gets an outgoing bidirectional stream", func() {
   296  						// need to open the stream ourselves first
   297  						// the peer is not allowed to create a stream initiated by us
   298  						_, err := m.OpenStream()
   299  						Expect(err).ToNot(HaveOccurred())
   300  						str, err := m.GetOrOpenReceiveStream(ids.firstOutgoingBidiStream)
   301  						Expect(err).ToNot(HaveOccurred())
   302  						Expect(str.StreamID()).To(Equal(ids.firstOutgoingBidiStream))
   303  					})
   304  
   305  					It("errors when the peer tries to open a higher outgoing bidirectional stream", func() {
   306  						id := ids.firstOutgoingBidiStream + 5*4
   307  						_, err := m.GetOrOpenReceiveStream(id)
   308  						Expect(err).To(MatchError(&qerr.TransportError{
   309  							ErrorCode:    qerr.StreamStateError,
   310  							ErrorMessage: fmt.Sprintf("peer attempted to open stream %d", id),
   311  						}))
   312  					})
   313  
   314  					It("gets an incoming bidirectional stream", func() {
   315  						id := ids.firstIncomingBidiStream + 4*7
   316  						str, err := m.GetOrOpenReceiveStream(id)
   317  						Expect(err).ToNot(HaveOccurred())
   318  						Expect(str.StreamID()).To(Equal(id))
   319  					})
   320  
   321  					It("gets an incoming unidirectional stream", func() {
   322  						id := ids.firstIncomingUniStream + 4*10
   323  						str, err := m.GetOrOpenReceiveStream(id)
   324  						Expect(err).ToNot(HaveOccurred())
   325  						Expect(str.StreamID()).To(Equal(id))
   326  					})
   327  
   328  					It("errors when trying to get an outgoing unidirectional stream", func() {
   329  						id := ids.firstOutgoingUniStream
   330  						_, err := m.GetOrOpenReceiveStream(id)
   331  						Expect(err).To(MatchError(&qerr.TransportError{
   332  							ErrorCode:    qerr.StreamStateError,
   333  							ErrorMessage: fmt.Sprintf("peer attempted to open receive stream %d", id),
   334  						}))
   335  					})
   336  				})
   337  			})
   338  
   339  			It("processes the parameter for outgoing streams", func() {
   340  				mockSender.EXPECT().queueControlFrame(gomock.Any())
   341  				_, err := m.OpenStream()
   342  				expectTooManyStreamsError(err)
   343  				m.UpdateLimits(&wire.TransportParameters{
   344  					MaxBidiStreamNum: 5,
   345  					MaxUniStreamNum:  8,
   346  				})
   347  
   348  				mockSender.EXPECT().queueControlFrame(gomock.Any()).Times(2)
   349  				// test we can only 5 bidirectional streams
   350  				for i := 0; i < 5; i++ {
   351  					str, err := m.OpenStream()
   352  					Expect(err).ToNot(HaveOccurred())
   353  					Expect(str.StreamID()).To(Equal(ids.firstOutgoingBidiStream + protocol.StreamID(4*i)))
   354  				}
   355  				_, err = m.OpenStream()
   356  				expectTooManyStreamsError(err)
   357  				// test we can only 8 unidirectional streams
   358  				for i := 0; i < 8; i++ {
   359  					str, err := m.OpenUniStream()
   360  					Expect(err).ToNot(HaveOccurred())
   361  					Expect(str.StreamID()).To(Equal(ids.firstOutgoingUniStream + protocol.StreamID(4*i)))
   362  				}
   363  				_, err = m.OpenUniStream()
   364  				expectTooManyStreamsError(err)
   365  			})
   366  
   367  			if perspective == protocol.PerspectiveClient {
   368  				It("applies parameters to existing streams (needed for 0-RTT)", func() {
   369  					m.UpdateLimits(&wire.TransportParameters{
   370  						MaxBidiStreamNum: 1000,
   371  						MaxUniStreamNum:  1000,
   372  					})
   373  					flowControllers := make(map[protocol.StreamID]*mocks.MockStreamFlowController)
   374  					m.newFlowController = func(id protocol.StreamID) flowcontrol.StreamFlowController {
   375  						fc := mocks.NewMockStreamFlowController(mockCtrl)
   376  						flowControllers[id] = fc
   377  						return fc
   378  					}
   379  
   380  					str, err := m.OpenStream()
   381  					Expect(err).ToNot(HaveOccurred())
   382  					unistr, err := m.OpenUniStream()
   383  					Expect(err).ToNot(HaveOccurred())
   384  
   385  					Expect(flowControllers).To(HaveKey(str.StreamID()))
   386  					flowControllers[str.StreamID()].EXPECT().UpdateSendWindow(protocol.ByteCount(4321))
   387  					Expect(flowControllers).To(HaveKey(unistr.StreamID()))
   388  					flowControllers[unistr.StreamID()].EXPECT().UpdateSendWindow(protocol.ByteCount(1234))
   389  
   390  					m.UpdateLimits(&wire.TransportParameters{
   391  						MaxBidiStreamNum:               1000,
   392  						InitialMaxStreamDataUni:        1234,
   393  						MaxUniStreamNum:                1000,
   394  						InitialMaxStreamDataBidiRemote: 4321,
   395  					})
   396  				})
   397  			}
   398  
   399  			Context("handling MAX_STREAMS frames", func() {
   400  				BeforeEach(func() {
   401  					mockSender.EXPECT().queueControlFrame(gomock.Any()).AnyTimes()
   402  				})
   403  
   404  				It("processes IDs for outgoing bidirectional streams", func() {
   405  					_, err := m.OpenStream()
   406  					expectTooManyStreamsError(err)
   407  					m.HandleMaxStreamsFrame(&wire.MaxStreamsFrame{
   408  						Type:         protocol.StreamTypeBidi,
   409  						MaxStreamNum: 1,
   410  					})
   411  					str, err := m.OpenStream()
   412  					Expect(err).ToNot(HaveOccurred())
   413  					Expect(str.StreamID()).To(Equal(ids.firstOutgoingBidiStream))
   414  					_, err = m.OpenStream()
   415  					expectTooManyStreamsError(err)
   416  				})
   417  
   418  				It("processes IDs for outgoing unidirectional streams", func() {
   419  					_, err := m.OpenUniStream()
   420  					expectTooManyStreamsError(err)
   421  					m.HandleMaxStreamsFrame(&wire.MaxStreamsFrame{
   422  						Type:         protocol.StreamTypeUni,
   423  						MaxStreamNum: 1,
   424  					})
   425  					str, err := m.OpenUniStream()
   426  					Expect(err).ToNot(HaveOccurred())
   427  					Expect(str.StreamID()).To(Equal(ids.firstOutgoingUniStream))
   428  					_, err = m.OpenUniStream()
   429  					expectTooManyStreamsError(err)
   430  				})
   431  			})
   432  
   433  			Context("sending MAX_STREAMS frames", func() {
   434  				It("sends a MAX_STREAMS frame for bidirectional streams", func() {
   435  					_, err := m.GetOrOpenReceiveStream(ids.firstIncomingBidiStream)
   436  					Expect(err).ToNot(HaveOccurred())
   437  					_, err = m.AcceptStream(context.Background())
   438  					Expect(err).ToNot(HaveOccurred())
   439  					mockSender.EXPECT().queueControlFrame(&wire.MaxStreamsFrame{
   440  						Type:         protocol.StreamTypeBidi,
   441  						MaxStreamNum: MaxBidiStreamNum + 1,
   442  					})
   443  					Expect(m.DeleteStream(ids.firstIncomingBidiStream)).To(Succeed())
   444  				})
   445  
   446  				It("sends a MAX_STREAMS frame for unidirectional streams", func() {
   447  					_, err := m.GetOrOpenReceiveStream(ids.firstIncomingUniStream)
   448  					Expect(err).ToNot(HaveOccurred())
   449  					_, err = m.AcceptUniStream(context.Background())
   450  					Expect(err).ToNot(HaveOccurred())
   451  					mockSender.EXPECT().queueControlFrame(&wire.MaxStreamsFrame{
   452  						Type:         protocol.StreamTypeUni,
   453  						MaxStreamNum: MaxUniStreamNum + 1,
   454  					})
   455  					Expect(m.DeleteStream(ids.firstIncomingUniStream)).To(Succeed())
   456  				})
   457  			})
   458  
   459  			It("closes", func() {
   460  				testErr := errors.New("test error")
   461  				m.CloseWithError(testErr)
   462  				_, err := m.OpenStream()
   463  				Expect(err).To(HaveOccurred())
   464  				Expect(err.Error()).To(Equal(testErr.Error()))
   465  				_, err = m.OpenUniStream()
   466  				Expect(err).To(HaveOccurred())
   467  				Expect(err.Error()).To(Equal(testErr.Error()))
   468  				_, err = m.AcceptStream(context.Background())
   469  				Expect(err).To(HaveOccurred())
   470  				Expect(err.Error()).To(Equal(testErr.Error()))
   471  				_, err = m.AcceptUniStream(context.Background())
   472  				Expect(err).To(HaveOccurred())
   473  				Expect(err.Error()).To(Equal(testErr.Error()))
   474  			})
   475  
   476  			if perspective == protocol.PerspectiveClient {
   477  				It("resets for 0-RTT", func() {
   478  					mockSender.EXPECT().queueControlFrame(gomock.Any()).AnyTimes()
   479  					m.ResetFor0RTT()
   480  					// make sure that calls to open / accept streams fail
   481  					_, err := m.OpenStream()
   482  					Expect(err).To(MatchError(Err0RTTRejected))
   483  					_, err = m.AcceptStream(context.Background())
   484  					Expect(err).To(MatchError(Err0RTTRejected))
   485  					// make sure that we can still get new streams, as the server might be sending us data
   486  					str, err := m.GetOrOpenReceiveStream(3)
   487  					Expect(err).ToNot(HaveOccurred())
   488  					Expect(str).ToNot(BeNil())
   489  
   490  					// now switch to using the new streams map
   491  					m.UseResetMaps()
   492  					_, err = m.OpenStream()
   493  					Expect(err).To(HaveOccurred())
   494  					Expect(err.Error()).To(ContainSubstring("too many open streams"))
   495  				})
   496  			}
   497  		})
   498  	}
   499  })