github.com/daeuniverse/quic-go@v0.0.0-20240413031024-943f218e0810/streams_map_incoming_test.go (about) 1 package quic 2 3 import ( 4 "context" 5 "errors" 6 "time" 7 8 "golang.org/x/exp/rand" 9 10 "github.com/daeuniverse/quic-go/internal/protocol" 11 "github.com/daeuniverse/quic-go/internal/wire" 12 13 . "github.com/onsi/ginkgo/v2" 14 . "github.com/onsi/gomega" 15 "go.uber.org/mock/gomock" 16 ) 17 18 type mockGenericStream struct { 19 num protocol.StreamNum 20 21 closed bool 22 closeErr error 23 sendWindow protocol.ByteCount 24 } 25 26 func (s *mockGenericStream) closeForShutdown(err error) { 27 s.closed = true 28 s.closeErr = err 29 } 30 31 func (s *mockGenericStream) updateSendWindow(limit protocol.ByteCount) { 32 s.sendWindow = limit 33 } 34 35 var _ = Describe("Streams Map (incoming)", func() { 36 var ( 37 m *incomingStreamsMap[*mockGenericStream] 38 newItemCounter int 39 mockSender *MockStreamSender 40 maxNumStreams uint64 41 ) 42 streamType := []protocol.StreamType{protocol.StreamTypeUni, protocol.StreamTypeUni}[rand.Intn(2)] 43 44 // check that the frame can be serialized and deserialized 45 checkFrameSerialization := func(f wire.Frame) { 46 b, err := f.Append(nil, protocol.Version1) 47 ExpectWithOffset(1, err).ToNot(HaveOccurred()) 48 _, frame, err := wire.NewFrameParser(false).ParseNext(b, protocol.Encryption1RTT, protocol.Version1) 49 ExpectWithOffset(1, err).ToNot(HaveOccurred()) 50 Expect(f).To(Equal(frame)) 51 } 52 53 BeforeEach(func() { maxNumStreams = 5 }) 54 55 JustBeforeEach(func() { 56 newItemCounter = 0 57 mockSender = NewMockStreamSender(mockCtrl) 58 m = newIncomingStreamsMap( 59 streamType, 60 func(num protocol.StreamNum) *mockGenericStream { 61 newItemCounter++ 62 return &mockGenericStream{num: num} 63 }, 64 maxNumStreams, 65 mockSender.queueControlFrame, 66 ) 67 }) 68 69 It("opens all streams up to the id on GetOrOpenStream", func() { 70 _, err := m.GetOrOpenStream(4) 71 Expect(err).ToNot(HaveOccurred()) 72 Expect(newItemCounter).To(Equal(4)) 73 }) 74 75 It("starts opening streams at the right position", func() { 76 // like the test above, but with 2 calls to GetOrOpenStream 77 _, err := m.GetOrOpenStream(2) 78 Expect(err).ToNot(HaveOccurred()) 79 Expect(newItemCounter).To(Equal(2)) 80 _, err = m.GetOrOpenStream(5) 81 Expect(err).ToNot(HaveOccurred()) 82 Expect(newItemCounter).To(Equal(5)) 83 }) 84 85 It("accepts streams in the right order", func() { 86 _, err := m.GetOrOpenStream(2) // open streams 1 and 2 87 Expect(err).ToNot(HaveOccurred()) 88 str, err := m.AcceptStream(context.Background()) 89 Expect(err).ToNot(HaveOccurred()) 90 Expect(str.num).To(Equal(protocol.StreamNum(1))) 91 str, err = m.AcceptStream(context.Background()) 92 Expect(err).ToNot(HaveOccurred()) 93 Expect(str.num).To(Equal(protocol.StreamNum(2))) 94 }) 95 96 It("allows opening the maximum stream ID", func() { 97 str, err := m.GetOrOpenStream(1) 98 Expect(err).ToNot(HaveOccurred()) 99 Expect(str.num).To(Equal(protocol.StreamNum(1))) 100 }) 101 102 It("errors when trying to get a stream ID higher than the maximum", func() { 103 _, err := m.GetOrOpenStream(6) 104 Expect(err).To(HaveOccurred()) 105 Expect(err.(streamError).TestError()).To(MatchError("peer tried to open stream 6 (current limit: 5)")) 106 }) 107 108 It("blocks AcceptStream until a new stream is available", func() { 109 strChan := make(chan *mockGenericStream) 110 go func() { 111 defer GinkgoRecover() 112 str, err := m.AcceptStream(context.Background()) 113 Expect(err).ToNot(HaveOccurred()) 114 strChan <- str 115 }() 116 Consistently(strChan).ShouldNot(Receive()) 117 str, err := m.GetOrOpenStream(1) 118 Expect(err).ToNot(HaveOccurred()) 119 Expect(str.num).To(Equal(protocol.StreamNum(1))) 120 var acceptedStr *mockGenericStream 121 Eventually(strChan).Should(Receive(&acceptedStr)) 122 Expect(acceptedStr.num).To(Equal(protocol.StreamNum(1))) 123 }) 124 125 It("unblocks AcceptStream when the context is canceled", func() { 126 ctx, cancel := context.WithCancel(context.Background()) 127 done := make(chan struct{}) 128 go func() { 129 defer GinkgoRecover() 130 _, err := m.AcceptStream(ctx) 131 Expect(err).To(MatchError("context canceled")) 132 close(done) 133 }() 134 Consistently(done).ShouldNot(BeClosed()) 135 cancel() 136 Eventually(done).Should(BeClosed()) 137 }) 138 139 It("unblocks AcceptStream when it is closed", func() { 140 testErr := errors.New("test error") 141 done := make(chan struct{}) 142 go func() { 143 defer GinkgoRecover() 144 _, err := m.AcceptStream(context.Background()) 145 Expect(err).To(MatchError(testErr)) 146 close(done) 147 }() 148 Consistently(done).ShouldNot(BeClosed()) 149 m.CloseWithError(testErr) 150 Eventually(done).Should(BeClosed()) 151 }) 152 153 It("errors AcceptStream immediately if it is closed", func() { 154 testErr := errors.New("test error") 155 m.CloseWithError(testErr) 156 _, err := m.AcceptStream(context.Background()) 157 Expect(err).To(MatchError(testErr)) 158 }) 159 160 It("closes all streams when CloseWithError is called", func() { 161 str1, err := m.GetOrOpenStream(1) 162 Expect(err).ToNot(HaveOccurred()) 163 str2, err := m.GetOrOpenStream(3) 164 Expect(err).ToNot(HaveOccurred()) 165 testErr := errors.New("test err") 166 m.CloseWithError(testErr) 167 Expect(str1.closed).To(BeTrue()) 168 Expect(str1.closeErr).To(MatchError(testErr)) 169 Expect(str2.closed).To(BeTrue()) 170 Expect(str2.closeErr).To(MatchError(testErr)) 171 }) 172 173 It("deletes streams", func() { 174 mockSender.EXPECT().queueControlFrame(gomock.Any()) 175 _, err := m.GetOrOpenStream(1) 176 Expect(err).ToNot(HaveOccurred()) 177 str, err := m.AcceptStream(context.Background()) 178 Expect(err).ToNot(HaveOccurred()) 179 Expect(str.num).To(Equal(protocol.StreamNum(1))) 180 Expect(m.DeleteStream(1)).To(Succeed()) 181 str, err = m.GetOrOpenStream(1) 182 Expect(err).ToNot(HaveOccurred()) 183 Expect(str).To(BeNil()) 184 }) 185 186 It("waits until a stream is accepted before actually deleting it", func() { 187 _, err := m.GetOrOpenStream(2) 188 Expect(err).ToNot(HaveOccurred()) 189 Expect(m.DeleteStream(2)).To(Succeed()) 190 str, err := m.AcceptStream(context.Background()) 191 Expect(err).ToNot(HaveOccurred()) 192 Expect(str.num).To(Equal(protocol.StreamNum(1))) 193 // when accepting this stream, it will get deleted, and a MAX_STREAMS frame is queued 194 mockSender.EXPECT().queueControlFrame(gomock.Any()) 195 str, err = m.AcceptStream(context.Background()) 196 Expect(err).ToNot(HaveOccurred()) 197 Expect(str.num).To(Equal(protocol.StreamNum(2))) 198 }) 199 200 It("doesn't return a stream queued for deleting from GetOrOpenStream", func() { 201 str, err := m.GetOrOpenStream(1) 202 Expect(err).ToNot(HaveOccurred()) 203 Expect(str).ToNot(BeNil()) 204 Expect(m.DeleteStream(1)).To(Succeed()) 205 str, err = m.GetOrOpenStream(1) 206 Expect(err).ToNot(HaveOccurred()) 207 Expect(str).To(BeNil()) 208 // when accepting this stream, it will get deleted, and a MAX_STREAMS frame is queued 209 mockSender.EXPECT().queueControlFrame(gomock.Any()) 210 str, err = m.AcceptStream(context.Background()) 211 Expect(err).ToNot(HaveOccurred()) 212 Expect(str).ToNot(BeNil()) 213 }) 214 215 It("errors when deleting a non-existing stream", func() { 216 err := m.DeleteStream(1337) 217 Expect(err).To(HaveOccurred()) 218 Expect(err.(streamError).TestError()).To(MatchError("tried to delete unknown incoming stream 1337")) 219 }) 220 221 It("sends MAX_STREAMS frames when streams are deleted", func() { 222 // open a bunch of streams 223 _, err := m.GetOrOpenStream(5) 224 Expect(err).ToNot(HaveOccurred()) 225 // accept all streams 226 for i := 0; i < 5; i++ { 227 _, err := m.AcceptStream(context.Background()) 228 Expect(err).ToNot(HaveOccurred()) 229 } 230 mockSender.EXPECT().queueControlFrame(gomock.Any()).Do(func(f wire.Frame) { 231 msf := f.(*wire.MaxStreamsFrame) 232 Expect(msf.Type).To(BeEquivalentTo(streamType)) 233 Expect(msf.MaxStreamNum).To(Equal(protocol.StreamNum(maxNumStreams + 1))) 234 checkFrameSerialization(f) 235 }) 236 Expect(m.DeleteStream(3)).To(Succeed()) 237 mockSender.EXPECT().queueControlFrame(gomock.Any()).Do(func(f wire.Frame) { 238 Expect(f.(*wire.MaxStreamsFrame).MaxStreamNum).To(Equal(protocol.StreamNum(maxNumStreams + 2))) 239 checkFrameSerialization(f) 240 }) 241 Expect(m.DeleteStream(4)).To(Succeed()) 242 }) 243 244 Context("using high stream limits", func() { 245 BeforeEach(func() { maxNumStreams = uint64(protocol.MaxStreamCount) - 2 }) 246 247 It("doesn't send MAX_STREAMS frames if they would overflow 2^60 (the maximum stream count)", func() { 248 // open a bunch of streams 249 _, err := m.GetOrOpenStream(5) 250 Expect(err).ToNot(HaveOccurred()) 251 // accept all streams 252 for i := 0; i < 5; i++ { 253 _, err := m.AcceptStream(context.Background()) 254 Expect(err).ToNot(HaveOccurred()) 255 } 256 mockSender.EXPECT().queueControlFrame(gomock.Any()).Do(func(f wire.Frame) { 257 Expect(f.(*wire.MaxStreamsFrame).MaxStreamNum).To(Equal(protocol.MaxStreamCount - 1)) 258 checkFrameSerialization(f) 259 }) 260 Expect(m.DeleteStream(4)).To(Succeed()) 261 mockSender.EXPECT().queueControlFrame(gomock.Any()).Do(func(f wire.Frame) { 262 Expect(f.(*wire.MaxStreamsFrame).MaxStreamNum).To(Equal(protocol.MaxStreamCount)) 263 checkFrameSerialization(f) 264 }) 265 Expect(m.DeleteStream(3)).To(Succeed()) 266 // at this point, we can't increase the stream limit any further, so no more MAX_STREAMS frames will be sent 267 Expect(m.DeleteStream(2)).To(Succeed()) 268 Expect(m.DeleteStream(1)).To(Succeed()) 269 }) 270 }) 271 272 Context("randomized tests", func() { 273 const num = 1000 274 275 BeforeEach(func() { maxNumStreams = num }) 276 277 It("opens and accepts streams", func() { 278 rand.Seed(uint64(GinkgoRandomSeed())) 279 ids := make([]protocol.StreamNum, num) 280 for i := 0; i < num; i++ { 281 ids[i] = protocol.StreamNum(i + 1) 282 } 283 rand.Shuffle(len(ids), func(i, j int) { ids[i], ids[j] = ids[j], ids[i] }) 284 285 const timeout = 5 * time.Second 286 done := make(chan struct{}, 2) 287 go func() { 288 defer GinkgoRecover() 289 ctx, cancel := context.WithTimeout(context.Background(), timeout) 290 defer cancel() 291 for i := 0; i < num; i++ { 292 _, err := m.AcceptStream(ctx) 293 Expect(err).ToNot(HaveOccurred()) 294 } 295 done <- struct{}{} 296 }() 297 298 go func() { 299 defer GinkgoRecover() 300 for i := 0; i < num; i++ { 301 _, err := m.GetOrOpenStream(ids[i]) 302 Expect(err).ToNot(HaveOccurred()) 303 } 304 done <- struct{}{} 305 }() 306 307 Eventually(done, timeout*3/2).Should(Receive()) 308 Eventually(done, timeout*3/2).Should(Receive()) 309 }) 310 }) 311 })