github.com/mikelsr/quic-go@v0.36.1-0.20230701132136-1d9415b66898/streams_map_test.go (about) 1 package quic 2 3 import ( 4 "context" 5 "errors" 6 "fmt" 7 "net" 8 9 "github.com/golang/mock/gomock" 10 11 "github.com/mikelsr/quic-go/internal/flowcontrol" 12 "github.com/mikelsr/quic-go/internal/mocks" 13 "github.com/mikelsr/quic-go/internal/protocol" 14 "github.com/mikelsr/quic-go/internal/qerr" 15 "github.com/mikelsr/quic-go/internal/wire" 16 17 . "github.com/onsi/ginkgo/v2" 18 . "github.com/onsi/gomega" 19 ) 20 21 func (e streamError) TestError() error { 22 nums := make([]interface{}, len(e.nums)) 23 for i, num := range e.nums { 24 nums[i] = num 25 } 26 return fmt.Errorf(e.message, nums...) 27 } 28 29 type streamMapping struct { 30 firstIncomingBidiStream protocol.StreamID 31 firstIncomingUniStream protocol.StreamID 32 firstOutgoingBidiStream protocol.StreamID 33 firstOutgoingUniStream protocol.StreamID 34 } 35 36 func expectTooManyStreamsError(err error) { 37 ExpectWithOffset(1, err).To(HaveOccurred()) 38 ExpectWithOffset(1, err.Error()).To(Equal(errTooManyOpenStreams.Error())) 39 nerr, ok := err.(net.Error) 40 ExpectWithOffset(1, ok).To(BeTrue()) 41 ExpectWithOffset(1, nerr.Timeout()).To(BeFalse()) 42 } 43 44 var _ = Describe("Streams Map", func() { 45 newFlowController := func(protocol.StreamID) flowcontrol.StreamFlowController { 46 return mocks.NewMockStreamFlowController(mockCtrl) 47 } 48 49 serverStreamMapping := streamMapping{ 50 firstIncomingBidiStream: 0, 51 firstOutgoingBidiStream: 1, 52 firstIncomingUniStream: 2, 53 firstOutgoingUniStream: 3, 54 } 55 clientStreamMapping := streamMapping{ 56 firstIncomingBidiStream: 1, 57 firstOutgoingBidiStream: 0, 58 firstIncomingUniStream: 3, 59 firstOutgoingUniStream: 2, 60 } 61 62 for _, p := range []protocol.Perspective{protocol.PerspectiveServer, protocol.PerspectiveClient} { 63 perspective := p 64 var ids streamMapping 65 if perspective == protocol.PerspectiveClient { 66 ids = clientStreamMapping 67 } else { 68 ids = serverStreamMapping 69 } 70 71 Context(perspective.String(), func() { 72 var ( 73 m *streamsMap 74 mockSender *MockStreamSender 75 ) 76 77 const ( 78 MaxBidiStreamNum = 111 79 MaxUniStreamNum = 222 80 ) 81 82 allowUnlimitedStreams := func() { 83 m.UpdateLimits(&wire.TransportParameters{ 84 MaxBidiStreamNum: protocol.MaxStreamCount, 85 MaxUniStreamNum: protocol.MaxStreamCount, 86 }) 87 } 88 89 BeforeEach(func() { 90 mockSender = NewMockStreamSender(mockCtrl) 91 m = newStreamsMap(mockSender, newFlowController, MaxBidiStreamNum, MaxUniStreamNum, perspective).(*streamsMap) 92 }) 93 94 Context("opening", func() { 95 It("opens bidirectional streams", func() { 96 allowUnlimitedStreams() 97 str, err := m.OpenStream() 98 Expect(err).ToNot(HaveOccurred()) 99 Expect(str).To(BeAssignableToTypeOf(&stream{})) 100 Expect(str.StreamID()).To(Equal(ids.firstOutgoingBidiStream)) 101 str, err = m.OpenStream() 102 Expect(err).ToNot(HaveOccurred()) 103 Expect(str).To(BeAssignableToTypeOf(&stream{})) 104 Expect(str.StreamID()).To(Equal(ids.firstOutgoingBidiStream + 4)) 105 }) 106 107 It("opens unidirectional streams", func() { 108 allowUnlimitedStreams() 109 str, err := m.OpenUniStream() 110 Expect(err).ToNot(HaveOccurred()) 111 Expect(str).To(BeAssignableToTypeOf(&sendStream{})) 112 Expect(str.StreamID()).To(Equal(ids.firstOutgoingUniStream)) 113 str, err = m.OpenUniStream() 114 Expect(err).ToNot(HaveOccurred()) 115 Expect(str).To(BeAssignableToTypeOf(&sendStream{})) 116 Expect(str.StreamID()).To(Equal(ids.firstOutgoingUniStream + 4)) 117 }) 118 }) 119 120 Context("accepting", func() { 121 It("accepts bidirectional streams", func() { 122 _, err := m.GetOrOpenReceiveStream(ids.firstIncomingBidiStream) 123 Expect(err).ToNot(HaveOccurred()) 124 str, err := m.AcceptStream(context.Background()) 125 Expect(err).ToNot(HaveOccurred()) 126 Expect(str).To(BeAssignableToTypeOf(&stream{})) 127 Expect(str.StreamID()).To(Equal(ids.firstIncomingBidiStream)) 128 }) 129 130 It("accepts unidirectional streams", func() { 131 _, err := m.GetOrOpenReceiveStream(ids.firstIncomingUniStream) 132 Expect(err).ToNot(HaveOccurred()) 133 str, err := m.AcceptUniStream(context.Background()) 134 Expect(err).ToNot(HaveOccurred()) 135 Expect(str).To(BeAssignableToTypeOf(&receiveStream{})) 136 Expect(str.StreamID()).To(Equal(ids.firstIncomingUniStream)) 137 }) 138 }) 139 140 Context("deleting", func() { 141 BeforeEach(func() { 142 mockSender.EXPECT().queueControlFrame(gomock.Any()).AnyTimes() 143 allowUnlimitedStreams() 144 }) 145 146 It("deletes outgoing bidirectional streams", func() { 147 id := ids.firstOutgoingBidiStream 148 str, err := m.OpenStream() 149 Expect(err).ToNot(HaveOccurred()) 150 Expect(str.StreamID()).To(Equal(id)) 151 Expect(m.DeleteStream(id)).To(Succeed()) 152 dstr, err := m.GetOrOpenSendStream(id) 153 Expect(err).ToNot(HaveOccurred()) 154 Expect(dstr).To(BeNil()) 155 }) 156 157 It("deletes incoming bidirectional streams", func() { 158 id := ids.firstIncomingBidiStream 159 str, err := m.GetOrOpenReceiveStream(id) 160 Expect(err).ToNot(HaveOccurred()) 161 Expect(str.StreamID()).To(Equal(id)) 162 Expect(m.DeleteStream(id)).To(Succeed()) 163 dstr, err := m.GetOrOpenReceiveStream(id) 164 Expect(err).ToNot(HaveOccurred()) 165 Expect(dstr).To(BeNil()) 166 }) 167 168 It("accepts bidirectional streams after they have been deleted", func() { 169 id := ids.firstIncomingBidiStream 170 _, err := m.GetOrOpenReceiveStream(id) 171 Expect(err).ToNot(HaveOccurred()) 172 Expect(m.DeleteStream(id)).To(Succeed()) 173 str, err := m.AcceptStream(context.Background()) 174 Expect(err).ToNot(HaveOccurred()) 175 Expect(str).ToNot(BeNil()) 176 Expect(str.StreamID()).To(Equal(id)) 177 }) 178 179 It("deletes outgoing unidirectional streams", func() { 180 id := ids.firstOutgoingUniStream 181 str, err := m.OpenUniStream() 182 Expect(err).ToNot(HaveOccurred()) 183 Expect(str.StreamID()).To(Equal(id)) 184 Expect(m.DeleteStream(id)).To(Succeed()) 185 dstr, err := m.GetOrOpenSendStream(id) 186 Expect(err).ToNot(HaveOccurred()) 187 Expect(dstr).To(BeNil()) 188 }) 189 190 It("deletes incoming unidirectional streams", func() { 191 id := ids.firstIncomingUniStream 192 str, err := m.GetOrOpenReceiveStream(id) 193 Expect(err).ToNot(HaveOccurred()) 194 Expect(str.StreamID()).To(Equal(id)) 195 Expect(m.DeleteStream(id)).To(Succeed()) 196 dstr, err := m.GetOrOpenReceiveStream(id) 197 Expect(err).ToNot(HaveOccurred()) 198 Expect(dstr).To(BeNil()) 199 }) 200 201 It("accepts unirectional streams after they have been deleted", func() { 202 id := ids.firstIncomingUniStream 203 _, err := m.GetOrOpenReceiveStream(id) 204 Expect(err).ToNot(HaveOccurred()) 205 Expect(m.DeleteStream(id)).To(Succeed()) 206 str, err := m.AcceptUniStream(context.Background()) 207 Expect(err).ToNot(HaveOccurred()) 208 Expect(str).ToNot(BeNil()) 209 Expect(str.StreamID()).To(Equal(id)) 210 }) 211 212 It("errors when deleting unknown incoming unidirectional streams", func() { 213 id := ids.firstIncomingUniStream + 4 214 Expect(m.DeleteStream(id)).To(MatchError(fmt.Sprintf("tried to delete unknown incoming stream %d", id))) 215 }) 216 217 It("errors when deleting unknown outgoing unidirectional streams", func() { 218 id := ids.firstOutgoingUniStream + 4 219 Expect(m.DeleteStream(id)).To(MatchError(fmt.Sprintf("tried to delete unknown outgoing stream %d", id))) 220 }) 221 222 It("errors when deleting unknown incoming bidirectional streams", func() { 223 id := ids.firstIncomingBidiStream + 4 224 Expect(m.DeleteStream(id)).To(MatchError(fmt.Sprintf("tried to delete unknown incoming stream %d", id))) 225 }) 226 227 It("errors when deleting unknown outgoing bidirectional streams", func() { 228 id := ids.firstOutgoingBidiStream + 4 229 Expect(m.DeleteStream(id)).To(MatchError(fmt.Sprintf("tried to delete unknown outgoing stream %d", id))) 230 }) 231 }) 232 233 Context("getting streams", func() { 234 BeforeEach(func() { 235 allowUnlimitedStreams() 236 }) 237 238 Context("send streams", func() { 239 It("gets an outgoing bidirectional stream", func() { 240 // need to open the stream ourselves first 241 // the peer is not allowed to create a stream initiated by us 242 _, err := m.OpenStream() 243 Expect(err).ToNot(HaveOccurred()) 244 str, err := m.GetOrOpenSendStream(ids.firstOutgoingBidiStream) 245 Expect(err).ToNot(HaveOccurred()) 246 Expect(str.StreamID()).To(Equal(ids.firstOutgoingBidiStream)) 247 }) 248 249 It("errors when the peer tries to open a higher outgoing bidirectional stream", func() { 250 id := ids.firstOutgoingBidiStream + 5*4 251 _, err := m.GetOrOpenSendStream(id) 252 Expect(err).To(MatchError(&qerr.TransportError{ 253 ErrorCode: qerr.StreamStateError, 254 ErrorMessage: fmt.Sprintf("peer attempted to open stream %d", id), 255 })) 256 }) 257 258 It("gets an outgoing unidirectional stream", func() { 259 // need to open the stream ourselves first 260 // the peer is not allowed to create a stream initiated by us 261 _, err := m.OpenUniStream() 262 Expect(err).ToNot(HaveOccurred()) 263 str, err := m.GetOrOpenSendStream(ids.firstOutgoingUniStream) 264 Expect(err).ToNot(HaveOccurred()) 265 Expect(str.StreamID()).To(Equal(ids.firstOutgoingUniStream)) 266 }) 267 268 It("errors when the peer tries to open a higher outgoing bidirectional stream", func() { 269 id := ids.firstOutgoingUniStream + 5*4 270 _, err := m.GetOrOpenSendStream(id) 271 Expect(err).To(MatchError(&qerr.TransportError{ 272 ErrorCode: qerr.StreamStateError, 273 ErrorMessage: fmt.Sprintf("peer attempted to open stream %d", id), 274 })) 275 }) 276 277 It("gets an incoming bidirectional stream", func() { 278 id := ids.firstIncomingBidiStream + 4*7 279 str, err := m.GetOrOpenSendStream(id) 280 Expect(err).ToNot(HaveOccurred()) 281 Expect(str.StreamID()).To(Equal(id)) 282 }) 283 284 It("errors when trying to get an incoming unidirectional stream", func() { 285 id := ids.firstIncomingUniStream 286 _, err := m.GetOrOpenSendStream(id) 287 Expect(err).To(MatchError(&qerr.TransportError{ 288 ErrorCode: qerr.StreamStateError, 289 ErrorMessage: fmt.Sprintf("peer attempted to open send stream %d", id), 290 })) 291 }) 292 }) 293 294 Context("receive streams", func() { 295 It("gets an outgoing bidirectional stream", func() { 296 // need to open the stream ourselves first 297 // the peer is not allowed to create a stream initiated by us 298 _, err := m.OpenStream() 299 Expect(err).ToNot(HaveOccurred()) 300 str, err := m.GetOrOpenReceiveStream(ids.firstOutgoingBidiStream) 301 Expect(err).ToNot(HaveOccurred()) 302 Expect(str.StreamID()).To(Equal(ids.firstOutgoingBidiStream)) 303 }) 304 305 It("errors when the peer tries to open a higher outgoing bidirectional stream", func() { 306 id := ids.firstOutgoingBidiStream + 5*4 307 _, err := m.GetOrOpenReceiveStream(id) 308 Expect(err).To(MatchError(&qerr.TransportError{ 309 ErrorCode: qerr.StreamStateError, 310 ErrorMessage: fmt.Sprintf("peer attempted to open stream %d", id), 311 })) 312 }) 313 314 It("gets an incoming bidirectional stream", func() { 315 id := ids.firstIncomingBidiStream + 4*7 316 str, err := m.GetOrOpenReceiveStream(id) 317 Expect(err).ToNot(HaveOccurred()) 318 Expect(str.StreamID()).To(Equal(id)) 319 }) 320 321 It("gets an incoming unidirectional stream", func() { 322 id := ids.firstIncomingUniStream + 4*10 323 str, err := m.GetOrOpenReceiveStream(id) 324 Expect(err).ToNot(HaveOccurred()) 325 Expect(str.StreamID()).To(Equal(id)) 326 }) 327 328 It("errors when trying to get an outgoing unidirectional stream", func() { 329 id := ids.firstOutgoingUniStream 330 _, err := m.GetOrOpenReceiveStream(id) 331 Expect(err).To(MatchError(&qerr.TransportError{ 332 ErrorCode: qerr.StreamStateError, 333 ErrorMessage: fmt.Sprintf("peer attempted to open receive stream %d", id), 334 })) 335 }) 336 }) 337 }) 338 339 It("processes the parameter for outgoing streams", func() { 340 mockSender.EXPECT().queueControlFrame(gomock.Any()) 341 _, err := m.OpenStream() 342 expectTooManyStreamsError(err) 343 m.UpdateLimits(&wire.TransportParameters{ 344 MaxBidiStreamNum: 5, 345 MaxUniStreamNum: 8, 346 }) 347 348 mockSender.EXPECT().queueControlFrame(gomock.Any()).Times(2) 349 // test we can only 5 bidirectional streams 350 for i := 0; i < 5; i++ { 351 str, err := m.OpenStream() 352 Expect(err).ToNot(HaveOccurred()) 353 Expect(str.StreamID()).To(Equal(ids.firstOutgoingBidiStream + protocol.StreamID(4*i))) 354 } 355 _, err = m.OpenStream() 356 expectTooManyStreamsError(err) 357 // test we can only 8 unidirectional streams 358 for i := 0; i < 8; i++ { 359 str, err := m.OpenUniStream() 360 Expect(err).ToNot(HaveOccurred()) 361 Expect(str.StreamID()).To(Equal(ids.firstOutgoingUniStream + protocol.StreamID(4*i))) 362 } 363 _, err = m.OpenUniStream() 364 expectTooManyStreamsError(err) 365 }) 366 367 if perspective == protocol.PerspectiveClient { 368 It("applies parameters to existing streams (needed for 0-RTT)", func() { 369 m.UpdateLimits(&wire.TransportParameters{ 370 MaxBidiStreamNum: 1000, 371 MaxUniStreamNum: 1000, 372 }) 373 flowControllers := make(map[protocol.StreamID]*mocks.MockStreamFlowController) 374 m.newFlowController = func(id protocol.StreamID) flowcontrol.StreamFlowController { 375 fc := mocks.NewMockStreamFlowController(mockCtrl) 376 flowControllers[id] = fc 377 return fc 378 } 379 380 str, err := m.OpenStream() 381 Expect(err).ToNot(HaveOccurred()) 382 unistr, err := m.OpenUniStream() 383 Expect(err).ToNot(HaveOccurred()) 384 385 Expect(flowControllers).To(HaveKey(str.StreamID())) 386 flowControllers[str.StreamID()].EXPECT().UpdateSendWindow(protocol.ByteCount(4321)) 387 Expect(flowControllers).To(HaveKey(unistr.StreamID())) 388 flowControllers[unistr.StreamID()].EXPECT().UpdateSendWindow(protocol.ByteCount(1234)) 389 390 m.UpdateLimits(&wire.TransportParameters{ 391 MaxBidiStreamNum: 1000, 392 InitialMaxStreamDataUni: 1234, 393 MaxUniStreamNum: 1000, 394 InitialMaxStreamDataBidiRemote: 4321, 395 }) 396 }) 397 } 398 399 Context("handling MAX_STREAMS frames", func() { 400 BeforeEach(func() { 401 mockSender.EXPECT().queueControlFrame(gomock.Any()).AnyTimes() 402 }) 403 404 It("processes IDs for outgoing bidirectional streams", func() { 405 _, err := m.OpenStream() 406 expectTooManyStreamsError(err) 407 m.HandleMaxStreamsFrame(&wire.MaxStreamsFrame{ 408 Type: protocol.StreamTypeBidi, 409 MaxStreamNum: 1, 410 }) 411 str, err := m.OpenStream() 412 Expect(err).ToNot(HaveOccurred()) 413 Expect(str.StreamID()).To(Equal(ids.firstOutgoingBidiStream)) 414 _, err = m.OpenStream() 415 expectTooManyStreamsError(err) 416 }) 417 418 It("processes IDs for outgoing unidirectional streams", func() { 419 _, err := m.OpenUniStream() 420 expectTooManyStreamsError(err) 421 m.HandleMaxStreamsFrame(&wire.MaxStreamsFrame{ 422 Type: protocol.StreamTypeUni, 423 MaxStreamNum: 1, 424 }) 425 str, err := m.OpenUniStream() 426 Expect(err).ToNot(HaveOccurred()) 427 Expect(str.StreamID()).To(Equal(ids.firstOutgoingUniStream)) 428 _, err = m.OpenUniStream() 429 expectTooManyStreamsError(err) 430 }) 431 }) 432 433 Context("sending MAX_STREAMS frames", func() { 434 It("sends a MAX_STREAMS frame for bidirectional streams", func() { 435 _, err := m.GetOrOpenReceiveStream(ids.firstIncomingBidiStream) 436 Expect(err).ToNot(HaveOccurred()) 437 _, err = m.AcceptStream(context.Background()) 438 Expect(err).ToNot(HaveOccurred()) 439 mockSender.EXPECT().queueControlFrame(&wire.MaxStreamsFrame{ 440 Type: protocol.StreamTypeBidi, 441 MaxStreamNum: MaxBidiStreamNum + 1, 442 }) 443 Expect(m.DeleteStream(ids.firstIncomingBidiStream)).To(Succeed()) 444 }) 445 446 It("sends a MAX_STREAMS frame for unidirectional streams", func() { 447 _, err := m.GetOrOpenReceiveStream(ids.firstIncomingUniStream) 448 Expect(err).ToNot(HaveOccurred()) 449 _, err = m.AcceptUniStream(context.Background()) 450 Expect(err).ToNot(HaveOccurred()) 451 mockSender.EXPECT().queueControlFrame(&wire.MaxStreamsFrame{ 452 Type: protocol.StreamTypeUni, 453 MaxStreamNum: MaxUniStreamNum + 1, 454 }) 455 Expect(m.DeleteStream(ids.firstIncomingUniStream)).To(Succeed()) 456 }) 457 }) 458 459 It("closes", func() { 460 testErr := errors.New("test error") 461 m.CloseWithError(testErr) 462 _, err := m.OpenStream() 463 Expect(err).To(HaveOccurred()) 464 Expect(err.Error()).To(Equal(testErr.Error())) 465 _, err = m.OpenUniStream() 466 Expect(err).To(HaveOccurred()) 467 Expect(err.Error()).To(Equal(testErr.Error())) 468 _, err = m.AcceptStream(context.Background()) 469 Expect(err).To(HaveOccurred()) 470 Expect(err.Error()).To(Equal(testErr.Error())) 471 _, err = m.AcceptUniStream(context.Background()) 472 Expect(err).To(HaveOccurred()) 473 Expect(err.Error()).To(Equal(testErr.Error())) 474 }) 475 476 if perspective == protocol.PerspectiveClient { 477 It("resets for 0-RTT", func() { 478 mockSender.EXPECT().queueControlFrame(gomock.Any()).AnyTimes() 479 m.ResetFor0RTT() 480 // make sure that calls to open / accept streams fail 481 _, err := m.OpenStream() 482 Expect(err).To(MatchError(Err0RTTRejected)) 483 _, err = m.AcceptStream(context.Background()) 484 Expect(err).To(MatchError(Err0RTTRejected)) 485 // make sure that we can still get new streams, as the server might be sending us data 486 str, err := m.GetOrOpenReceiveStream(3) 487 Expect(err).ToNot(HaveOccurred()) 488 Expect(str).ToNot(BeNil()) 489 490 // now switch to using the new streams map 491 m.UseResetMaps() 492 _, err = m.OpenStream() 493 Expect(err).To(HaveOccurred()) 494 Expect(err.Error()).To(ContainSubstring("too many open streams")) 495 }) 496 } 497 }) 498 } 499 })