github.com/MerlinKodo/quic-go@v0.39.2/streams_map_test.go (about)

     1  package quic
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  	"fmt"
     7  	"net"
     8  
     9  	"github.com/MerlinKodo/quic-go/internal/flowcontrol"
    10  	"github.com/MerlinKodo/quic-go/internal/mocks"
    11  	"github.com/MerlinKodo/quic-go/internal/protocol"
    12  	"github.com/MerlinKodo/quic-go/internal/qerr"
    13  	"github.com/MerlinKodo/quic-go/internal/wire"
    14  
    15  	. "github.com/onsi/ginkgo/v2"
    16  	. "github.com/onsi/gomega"
    17  	"go.uber.org/mock/gomock"
    18  )
    19  
    20  func (e streamError) TestError() error {
    21  	nums := make([]interface{}, len(e.nums))
    22  	for i, num := range e.nums {
    23  		nums[i] = num
    24  	}
    25  	return fmt.Errorf(e.message, nums...)
    26  }
    27  
    28  type streamMapping struct {
    29  	firstIncomingBidiStream protocol.StreamID
    30  	firstIncomingUniStream  protocol.StreamID
    31  	firstOutgoingBidiStream protocol.StreamID
    32  	firstOutgoingUniStream  protocol.StreamID
    33  }
    34  
    35  func expectTooManyStreamsError(err error) {
    36  	ExpectWithOffset(1, err).To(HaveOccurred())
    37  	ExpectWithOffset(1, err.Error()).To(Equal(errTooManyOpenStreams.Error()))
    38  	nerr, ok := err.(net.Error)
    39  	ExpectWithOffset(1, ok).To(BeTrue())
    40  	ExpectWithOffset(1, nerr.Timeout()).To(BeFalse())
    41  }
    42  
    43  var _ = Describe("Streams Map", func() {
    44  	newFlowController := func(protocol.StreamID) flowcontrol.StreamFlowController {
    45  		return mocks.NewMockStreamFlowController(mockCtrl)
    46  	}
    47  
    48  	serverStreamMapping := streamMapping{
    49  		firstIncomingBidiStream: 0,
    50  		firstOutgoingBidiStream: 1,
    51  		firstIncomingUniStream:  2,
    52  		firstOutgoingUniStream:  3,
    53  	}
    54  	clientStreamMapping := streamMapping{
    55  		firstIncomingBidiStream: 1,
    56  		firstOutgoingBidiStream: 0,
    57  		firstIncomingUniStream:  3,
    58  		firstOutgoingUniStream:  2,
    59  	}
    60  
    61  	for _, p := range []protocol.Perspective{protocol.PerspectiveServer, protocol.PerspectiveClient} {
    62  		perspective := p
    63  		var ids streamMapping
    64  		if perspective == protocol.PerspectiveClient {
    65  			ids = clientStreamMapping
    66  		} else {
    67  			ids = serverStreamMapping
    68  		}
    69  
    70  		Context(perspective.String(), func() {
    71  			var (
    72  				m          *streamsMap
    73  				mockSender *MockStreamSender
    74  			)
    75  
    76  			const (
    77  				MaxBidiStreamNum = 111
    78  				MaxUniStreamNum  = 222
    79  			)
    80  
    81  			allowUnlimitedStreams := func() {
    82  				m.UpdateLimits(&wire.TransportParameters{
    83  					MaxBidiStreamNum: protocol.MaxStreamCount,
    84  					MaxUniStreamNum:  protocol.MaxStreamCount,
    85  				})
    86  			}
    87  
    88  			BeforeEach(func() {
    89  				mockSender = NewMockStreamSender(mockCtrl)
    90  				m = newStreamsMap(mockSender, newFlowController, MaxBidiStreamNum, MaxUniStreamNum, perspective).(*streamsMap)
    91  			})
    92  
    93  			Context("opening", func() {
    94  				It("opens bidirectional streams", func() {
    95  					allowUnlimitedStreams()
    96  					str, err := m.OpenStream()
    97  					Expect(err).ToNot(HaveOccurred())
    98  					Expect(str).To(BeAssignableToTypeOf(&stream{}))
    99  					Expect(str.StreamID()).To(Equal(ids.firstOutgoingBidiStream))
   100  					str, err = m.OpenStream()
   101  					Expect(err).ToNot(HaveOccurred())
   102  					Expect(str).To(BeAssignableToTypeOf(&stream{}))
   103  					Expect(str.StreamID()).To(Equal(ids.firstOutgoingBidiStream + 4))
   104  				})
   105  
   106  				It("opens unidirectional streams", func() {
   107  					allowUnlimitedStreams()
   108  					str, err := m.OpenUniStream()
   109  					Expect(err).ToNot(HaveOccurred())
   110  					Expect(str).To(BeAssignableToTypeOf(&sendStream{}))
   111  					Expect(str.StreamID()).To(Equal(ids.firstOutgoingUniStream))
   112  					str, err = m.OpenUniStream()
   113  					Expect(err).ToNot(HaveOccurred())
   114  					Expect(str).To(BeAssignableToTypeOf(&sendStream{}))
   115  					Expect(str.StreamID()).To(Equal(ids.firstOutgoingUniStream + 4))
   116  				})
   117  			})
   118  
   119  			Context("accepting", func() {
   120  				It("accepts bidirectional streams", func() {
   121  					_, err := m.GetOrOpenReceiveStream(ids.firstIncomingBidiStream)
   122  					Expect(err).ToNot(HaveOccurred())
   123  					str, err := m.AcceptStream(context.Background())
   124  					Expect(err).ToNot(HaveOccurred())
   125  					Expect(str).To(BeAssignableToTypeOf(&stream{}))
   126  					Expect(str.StreamID()).To(Equal(ids.firstIncomingBidiStream))
   127  				})
   128  
   129  				It("accepts unidirectional streams", func() {
   130  					_, err := m.GetOrOpenReceiveStream(ids.firstIncomingUniStream)
   131  					Expect(err).ToNot(HaveOccurred())
   132  					str, err := m.AcceptUniStream(context.Background())
   133  					Expect(err).ToNot(HaveOccurred())
   134  					Expect(str).To(BeAssignableToTypeOf(&receiveStream{}))
   135  					Expect(str.StreamID()).To(Equal(ids.firstIncomingUniStream))
   136  				})
   137  			})
   138  
   139  			Context("deleting", func() {
   140  				BeforeEach(func() {
   141  					mockSender.EXPECT().queueControlFrame(gomock.Any()).AnyTimes()
   142  					allowUnlimitedStreams()
   143  				})
   144  
   145  				It("deletes outgoing bidirectional streams", func() {
   146  					id := ids.firstOutgoingBidiStream
   147  					str, err := m.OpenStream()
   148  					Expect(err).ToNot(HaveOccurred())
   149  					Expect(str.StreamID()).To(Equal(id))
   150  					Expect(m.DeleteStream(id)).To(Succeed())
   151  					dstr, err := m.GetOrOpenSendStream(id)
   152  					Expect(err).ToNot(HaveOccurred())
   153  					Expect(dstr).To(BeNil())
   154  				})
   155  
   156  				It("deletes incoming bidirectional streams", func() {
   157  					id := ids.firstIncomingBidiStream
   158  					str, err := m.GetOrOpenReceiveStream(id)
   159  					Expect(err).ToNot(HaveOccurred())
   160  					Expect(str.StreamID()).To(Equal(id))
   161  					Expect(m.DeleteStream(id)).To(Succeed())
   162  					dstr, err := m.GetOrOpenReceiveStream(id)
   163  					Expect(err).ToNot(HaveOccurred())
   164  					Expect(dstr).To(BeNil())
   165  				})
   166  
   167  				It("accepts bidirectional streams after they have been deleted", func() {
   168  					id := ids.firstIncomingBidiStream
   169  					_, err := m.GetOrOpenReceiveStream(id)
   170  					Expect(err).ToNot(HaveOccurred())
   171  					Expect(m.DeleteStream(id)).To(Succeed())
   172  					str, err := m.AcceptStream(context.Background())
   173  					Expect(err).ToNot(HaveOccurred())
   174  					Expect(str).ToNot(BeNil())
   175  					Expect(str.StreamID()).To(Equal(id))
   176  				})
   177  
   178  				It("deletes outgoing unidirectional streams", func() {
   179  					id := ids.firstOutgoingUniStream
   180  					str, err := m.OpenUniStream()
   181  					Expect(err).ToNot(HaveOccurred())
   182  					Expect(str.StreamID()).To(Equal(id))
   183  					Expect(m.DeleteStream(id)).To(Succeed())
   184  					dstr, err := m.GetOrOpenSendStream(id)
   185  					Expect(err).ToNot(HaveOccurred())
   186  					Expect(dstr).To(BeNil())
   187  				})
   188  
   189  				It("deletes incoming unidirectional streams", func() {
   190  					id := ids.firstIncomingUniStream
   191  					str, err := m.GetOrOpenReceiveStream(id)
   192  					Expect(err).ToNot(HaveOccurred())
   193  					Expect(str.StreamID()).To(Equal(id))
   194  					Expect(m.DeleteStream(id)).To(Succeed())
   195  					dstr, err := m.GetOrOpenReceiveStream(id)
   196  					Expect(err).ToNot(HaveOccurred())
   197  					Expect(dstr).To(BeNil())
   198  				})
   199  
   200  				It("accepts unirectional streams after they have been deleted", func() {
   201  					id := ids.firstIncomingUniStream
   202  					_, err := m.GetOrOpenReceiveStream(id)
   203  					Expect(err).ToNot(HaveOccurred())
   204  					Expect(m.DeleteStream(id)).To(Succeed())
   205  					str, err := m.AcceptUniStream(context.Background())
   206  					Expect(err).ToNot(HaveOccurred())
   207  					Expect(str).ToNot(BeNil())
   208  					Expect(str.StreamID()).To(Equal(id))
   209  				})
   210  
   211  				It("errors when deleting unknown incoming unidirectional streams", func() {
   212  					id := ids.firstIncomingUniStream + 4
   213  					Expect(m.DeleteStream(id)).To(MatchError(fmt.Sprintf("tried to delete unknown incoming stream %d", id)))
   214  				})
   215  
   216  				It("errors when deleting unknown outgoing unidirectional streams", func() {
   217  					id := ids.firstOutgoingUniStream + 4
   218  					Expect(m.DeleteStream(id)).To(MatchError(fmt.Sprintf("tried to delete unknown outgoing stream %d", id)))
   219  				})
   220  
   221  				It("errors when deleting unknown incoming bidirectional streams", func() {
   222  					id := ids.firstIncomingBidiStream + 4
   223  					Expect(m.DeleteStream(id)).To(MatchError(fmt.Sprintf("tried to delete unknown incoming stream %d", id)))
   224  				})
   225  
   226  				It("errors when deleting unknown outgoing bidirectional streams", func() {
   227  					id := ids.firstOutgoingBidiStream + 4
   228  					Expect(m.DeleteStream(id)).To(MatchError(fmt.Sprintf("tried to delete unknown outgoing stream %d", id)))
   229  				})
   230  			})
   231  
   232  			Context("getting streams", func() {
   233  				BeforeEach(func() {
   234  					allowUnlimitedStreams()
   235  				})
   236  
   237  				Context("send streams", func() {
   238  					It("gets an outgoing bidirectional stream", func() {
   239  						// need to open the stream ourselves first
   240  						// the peer is not allowed to create a stream initiated by us
   241  						_, err := m.OpenStream()
   242  						Expect(err).ToNot(HaveOccurred())
   243  						str, err := m.GetOrOpenSendStream(ids.firstOutgoingBidiStream)
   244  						Expect(err).ToNot(HaveOccurred())
   245  						Expect(str.StreamID()).To(Equal(ids.firstOutgoingBidiStream))
   246  					})
   247  
   248  					It("errors when the peer tries to open a higher outgoing bidirectional stream", func() {
   249  						id := ids.firstOutgoingBidiStream + 5*4
   250  						_, err := m.GetOrOpenSendStream(id)
   251  						Expect(err).To(MatchError(&qerr.TransportError{
   252  							ErrorCode:    qerr.StreamStateError,
   253  							ErrorMessage: fmt.Sprintf("peer attempted to open stream %d", id),
   254  						}))
   255  					})
   256  
   257  					It("gets an outgoing unidirectional stream", func() {
   258  						// need to open the stream ourselves first
   259  						// the peer is not allowed to create a stream initiated by us
   260  						_, err := m.OpenUniStream()
   261  						Expect(err).ToNot(HaveOccurred())
   262  						str, err := m.GetOrOpenSendStream(ids.firstOutgoingUniStream)
   263  						Expect(err).ToNot(HaveOccurred())
   264  						Expect(str.StreamID()).To(Equal(ids.firstOutgoingUniStream))
   265  					})
   266  
   267  					It("errors when the peer tries to open a higher outgoing bidirectional stream", func() {
   268  						id := ids.firstOutgoingUniStream + 5*4
   269  						_, err := m.GetOrOpenSendStream(id)
   270  						Expect(err).To(MatchError(&qerr.TransportError{
   271  							ErrorCode:    qerr.StreamStateError,
   272  							ErrorMessage: fmt.Sprintf("peer attempted to open stream %d", id),
   273  						}))
   274  					})
   275  
   276  					It("gets an incoming bidirectional stream", func() {
   277  						id := ids.firstIncomingBidiStream + 4*7
   278  						str, err := m.GetOrOpenSendStream(id)
   279  						Expect(err).ToNot(HaveOccurred())
   280  						Expect(str.StreamID()).To(Equal(id))
   281  					})
   282  
   283  					It("errors when trying to get an incoming unidirectional stream", func() {
   284  						id := ids.firstIncomingUniStream
   285  						_, err := m.GetOrOpenSendStream(id)
   286  						Expect(err).To(MatchError(&qerr.TransportError{
   287  							ErrorCode:    qerr.StreamStateError,
   288  							ErrorMessage: fmt.Sprintf("peer attempted to open send stream %d", id),
   289  						}))
   290  					})
   291  				})
   292  
   293  				Context("receive streams", func() {
   294  					It("gets an outgoing bidirectional stream", func() {
   295  						// need to open the stream ourselves first
   296  						// the peer is not allowed to create a stream initiated by us
   297  						_, err := m.OpenStream()
   298  						Expect(err).ToNot(HaveOccurred())
   299  						str, err := m.GetOrOpenReceiveStream(ids.firstOutgoingBidiStream)
   300  						Expect(err).ToNot(HaveOccurred())
   301  						Expect(str.StreamID()).To(Equal(ids.firstOutgoingBidiStream))
   302  					})
   303  
   304  					It("errors when the peer tries to open a higher outgoing bidirectional stream", func() {
   305  						id := ids.firstOutgoingBidiStream + 5*4
   306  						_, err := m.GetOrOpenReceiveStream(id)
   307  						Expect(err).To(MatchError(&qerr.TransportError{
   308  							ErrorCode:    qerr.StreamStateError,
   309  							ErrorMessage: fmt.Sprintf("peer attempted to open stream %d", id),
   310  						}))
   311  					})
   312  
   313  					It("gets an incoming bidirectional stream", func() {
   314  						id := ids.firstIncomingBidiStream + 4*7
   315  						str, err := m.GetOrOpenReceiveStream(id)
   316  						Expect(err).ToNot(HaveOccurred())
   317  						Expect(str.StreamID()).To(Equal(id))
   318  					})
   319  
   320  					It("gets an incoming unidirectional stream", func() {
   321  						id := ids.firstIncomingUniStream + 4*10
   322  						str, err := m.GetOrOpenReceiveStream(id)
   323  						Expect(err).ToNot(HaveOccurred())
   324  						Expect(str.StreamID()).To(Equal(id))
   325  					})
   326  
   327  					It("errors when trying to get an outgoing unidirectional stream", func() {
   328  						id := ids.firstOutgoingUniStream
   329  						_, err := m.GetOrOpenReceiveStream(id)
   330  						Expect(err).To(MatchError(&qerr.TransportError{
   331  							ErrorCode:    qerr.StreamStateError,
   332  							ErrorMessage: fmt.Sprintf("peer attempted to open receive stream %d", id),
   333  						}))
   334  					})
   335  				})
   336  			})
   337  
   338  			It("processes the parameter for outgoing streams", func() {
   339  				mockSender.EXPECT().queueControlFrame(gomock.Any())
   340  				_, err := m.OpenStream()
   341  				expectTooManyStreamsError(err)
   342  				m.UpdateLimits(&wire.TransportParameters{
   343  					MaxBidiStreamNum: 5,
   344  					MaxUniStreamNum:  8,
   345  				})
   346  
   347  				mockSender.EXPECT().queueControlFrame(gomock.Any()).Times(2)
   348  				// test we can only 5 bidirectional streams
   349  				for i := 0; i < 5; i++ {
   350  					str, err := m.OpenStream()
   351  					Expect(err).ToNot(HaveOccurred())
   352  					Expect(str.StreamID()).To(Equal(ids.firstOutgoingBidiStream + protocol.StreamID(4*i)))
   353  				}
   354  				_, err = m.OpenStream()
   355  				expectTooManyStreamsError(err)
   356  				// test we can only 8 unidirectional streams
   357  				for i := 0; i < 8; i++ {
   358  					str, err := m.OpenUniStream()
   359  					Expect(err).ToNot(HaveOccurred())
   360  					Expect(str.StreamID()).To(Equal(ids.firstOutgoingUniStream + protocol.StreamID(4*i)))
   361  				}
   362  				_, err = m.OpenUniStream()
   363  				expectTooManyStreamsError(err)
   364  			})
   365  
   366  			if perspective == protocol.PerspectiveClient {
   367  				It("applies parameters to existing streams (needed for 0-RTT)", func() {
   368  					m.UpdateLimits(&wire.TransportParameters{
   369  						MaxBidiStreamNum: 1000,
   370  						MaxUniStreamNum:  1000,
   371  					})
   372  					flowControllers := make(map[protocol.StreamID]*mocks.MockStreamFlowController)
   373  					m.newFlowController = func(id protocol.StreamID) flowcontrol.StreamFlowController {
   374  						fc := mocks.NewMockStreamFlowController(mockCtrl)
   375  						flowControllers[id] = fc
   376  						return fc
   377  					}
   378  
   379  					str, err := m.OpenStream()
   380  					Expect(err).ToNot(HaveOccurred())
   381  					unistr, err := m.OpenUniStream()
   382  					Expect(err).ToNot(HaveOccurred())
   383  
   384  					Expect(flowControllers).To(HaveKey(str.StreamID()))
   385  					flowControllers[str.StreamID()].EXPECT().UpdateSendWindow(protocol.ByteCount(4321))
   386  					Expect(flowControllers).To(HaveKey(unistr.StreamID()))
   387  					flowControllers[unistr.StreamID()].EXPECT().UpdateSendWindow(protocol.ByteCount(1234))
   388  
   389  					m.UpdateLimits(&wire.TransportParameters{
   390  						MaxBidiStreamNum:               1000,
   391  						InitialMaxStreamDataUni:        1234,
   392  						MaxUniStreamNum:                1000,
   393  						InitialMaxStreamDataBidiRemote: 4321,
   394  					})
   395  				})
   396  			}
   397  
   398  			Context("handling MAX_STREAMS frames", func() {
   399  				BeforeEach(func() {
   400  					mockSender.EXPECT().queueControlFrame(gomock.Any()).AnyTimes()
   401  				})
   402  
   403  				It("processes IDs for outgoing bidirectional streams", func() {
   404  					_, err := m.OpenStream()
   405  					expectTooManyStreamsError(err)
   406  					m.HandleMaxStreamsFrame(&wire.MaxStreamsFrame{
   407  						Type:         protocol.StreamTypeBidi,
   408  						MaxStreamNum: 1,
   409  					})
   410  					str, err := m.OpenStream()
   411  					Expect(err).ToNot(HaveOccurred())
   412  					Expect(str.StreamID()).To(Equal(ids.firstOutgoingBidiStream))
   413  					_, err = m.OpenStream()
   414  					expectTooManyStreamsError(err)
   415  				})
   416  
   417  				It("processes IDs for outgoing unidirectional streams", func() {
   418  					_, err := m.OpenUniStream()
   419  					expectTooManyStreamsError(err)
   420  					m.HandleMaxStreamsFrame(&wire.MaxStreamsFrame{
   421  						Type:         protocol.StreamTypeUni,
   422  						MaxStreamNum: 1,
   423  					})
   424  					str, err := m.OpenUniStream()
   425  					Expect(err).ToNot(HaveOccurred())
   426  					Expect(str.StreamID()).To(Equal(ids.firstOutgoingUniStream))
   427  					_, err = m.OpenUniStream()
   428  					expectTooManyStreamsError(err)
   429  				})
   430  			})
   431  
   432  			Context("sending MAX_STREAMS frames", func() {
   433  				It("sends a MAX_STREAMS frame for bidirectional streams", func() {
   434  					_, err := m.GetOrOpenReceiveStream(ids.firstIncomingBidiStream)
   435  					Expect(err).ToNot(HaveOccurred())
   436  					_, err = m.AcceptStream(context.Background())
   437  					Expect(err).ToNot(HaveOccurred())
   438  					mockSender.EXPECT().queueControlFrame(&wire.MaxStreamsFrame{
   439  						Type:         protocol.StreamTypeBidi,
   440  						MaxStreamNum: MaxBidiStreamNum + 1,
   441  					})
   442  					Expect(m.DeleteStream(ids.firstIncomingBidiStream)).To(Succeed())
   443  				})
   444  
   445  				It("sends a MAX_STREAMS frame for unidirectional streams", func() {
   446  					_, err := m.GetOrOpenReceiveStream(ids.firstIncomingUniStream)
   447  					Expect(err).ToNot(HaveOccurred())
   448  					_, err = m.AcceptUniStream(context.Background())
   449  					Expect(err).ToNot(HaveOccurred())
   450  					mockSender.EXPECT().queueControlFrame(&wire.MaxStreamsFrame{
   451  						Type:         protocol.StreamTypeUni,
   452  						MaxStreamNum: MaxUniStreamNum + 1,
   453  					})
   454  					Expect(m.DeleteStream(ids.firstIncomingUniStream)).To(Succeed())
   455  				})
   456  			})
   457  
   458  			It("closes", func() {
   459  				testErr := errors.New("test error")
   460  				m.CloseWithError(testErr)
   461  				_, err := m.OpenStream()
   462  				Expect(err).To(HaveOccurred())
   463  				Expect(err.Error()).To(Equal(testErr.Error()))
   464  				_, err = m.OpenUniStream()
   465  				Expect(err).To(HaveOccurred())
   466  				Expect(err.Error()).To(Equal(testErr.Error()))
   467  				_, err = m.AcceptStream(context.Background())
   468  				Expect(err).To(HaveOccurred())
   469  				Expect(err.Error()).To(Equal(testErr.Error()))
   470  				_, err = m.AcceptUniStream(context.Background())
   471  				Expect(err).To(HaveOccurred())
   472  				Expect(err.Error()).To(Equal(testErr.Error()))
   473  			})
   474  
   475  			if perspective == protocol.PerspectiveClient {
   476  				It("resets for 0-RTT", func() {
   477  					mockSender.EXPECT().queueControlFrame(gomock.Any()).AnyTimes()
   478  					m.ResetFor0RTT()
   479  					// make sure that calls to open / accept streams fail
   480  					_, err := m.OpenStream()
   481  					Expect(err).To(MatchError(Err0RTTRejected))
   482  					_, err = m.AcceptStream(context.Background())
   483  					Expect(err).To(MatchError(Err0RTTRejected))
   484  					// make sure that we can still get new streams, as the server might be sending us data
   485  					str, err := m.GetOrOpenReceiveStream(3)
   486  					Expect(err).ToNot(HaveOccurred())
   487  					Expect(str).ToNot(BeNil())
   488  
   489  					// now switch to using the new streams map
   490  					m.UseResetMaps()
   491  					_, err = m.OpenStream()
   492  					Expect(err).To(HaveOccurred())
   493  					Expect(err.Error()).To(ContainSubstring("too many open streams"))
   494  				})
   495  			}
   496  		})
   497  	}
   498  })