github.com/mikelsr/quic-go@v0.36.1-0.20230701132136-1d9415b66898/streams_map_incoming_test.go (about) 1 package quic 2 3 import ( 4 "context" 5 "errors" 6 "math/rand" 7 "time" 8 9 "github.com/mikelsr/quic-go/internal/protocol" 10 "github.com/mikelsr/quic-go/internal/wire" 11 12 "github.com/golang/mock/gomock" 13 . "github.com/onsi/ginkgo/v2" 14 . "github.com/onsi/gomega" 15 ) 16 17 type mockGenericStream struct { 18 num protocol.StreamNum 19 20 closed bool 21 closeErr error 22 sendWindow protocol.ByteCount 23 } 24 25 func (s *mockGenericStream) closeForShutdown(err error) { 26 s.closed = true 27 s.closeErr = err 28 } 29 30 func (s *mockGenericStream) updateSendWindow(limit protocol.ByteCount) { 31 s.sendWindow = limit 32 } 33 34 var _ = Describe("Streams Map (incoming)", func() { 35 var ( 36 m *incomingStreamsMap[*mockGenericStream] 37 newItemCounter int 38 mockSender *MockStreamSender 39 maxNumStreams uint64 40 ) 41 streamType := []protocol.StreamType{protocol.StreamTypeUni, protocol.StreamTypeUni}[rand.Intn(2)] 42 43 // check that the frame can be serialized and deserialized 44 checkFrameSerialization := func(f wire.Frame) { 45 b, err := f.Append(nil, protocol.Version1) 46 ExpectWithOffset(1, err).ToNot(HaveOccurred()) 47 _, frame, err := wire.NewFrameParser(false).ParseNext(b, protocol.Encryption1RTT, protocol.Version1) 48 ExpectWithOffset(1, err).ToNot(HaveOccurred()) 49 Expect(f).To(Equal(frame)) 50 } 51 52 BeforeEach(func() { maxNumStreams = 5 }) 53 54 JustBeforeEach(func() { 55 newItemCounter = 0 56 mockSender = NewMockStreamSender(mockCtrl) 57 m = newIncomingStreamsMap( 58 streamType, 59 func(num protocol.StreamNum) *mockGenericStream { 60 newItemCounter++ 61 return &mockGenericStream{num: num} 62 }, 63 maxNumStreams, 64 mockSender.queueControlFrame, 65 ) 66 }) 67 68 It("opens all streams up to the id on GetOrOpenStream", func() { 69 _, err := m.GetOrOpenStream(4) 70 Expect(err).ToNot(HaveOccurred()) 71 Expect(newItemCounter).To(Equal(4)) 72 }) 73 74 It("starts opening streams at the right position", func() { 75 // like the test above, but with 2 calls to GetOrOpenStream 76 _, err := m.GetOrOpenStream(2) 77 Expect(err).ToNot(HaveOccurred()) 78 Expect(newItemCounter).To(Equal(2)) 79 _, err = m.GetOrOpenStream(5) 80 Expect(err).ToNot(HaveOccurred()) 81 Expect(newItemCounter).To(Equal(5)) 82 }) 83 84 It("accepts streams in the right order", func() { 85 _, err := m.GetOrOpenStream(2) // open streams 1 and 2 86 Expect(err).ToNot(HaveOccurred()) 87 str, err := m.AcceptStream(context.Background()) 88 Expect(err).ToNot(HaveOccurred()) 89 Expect(str.num).To(Equal(protocol.StreamNum(1))) 90 str, err = m.AcceptStream(context.Background()) 91 Expect(err).ToNot(HaveOccurred()) 92 Expect(str.num).To(Equal(protocol.StreamNum(2))) 93 }) 94 95 It("allows opening the maximum stream ID", func() { 96 str, err := m.GetOrOpenStream(1) 97 Expect(err).ToNot(HaveOccurred()) 98 Expect(str.num).To(Equal(protocol.StreamNum(1))) 99 }) 100 101 It("errors when trying to get a stream ID higher than the maximum", func() { 102 _, err := m.GetOrOpenStream(6) 103 Expect(err).To(HaveOccurred()) 104 Expect(err.(streamError).TestError()).To(MatchError("peer tried to open stream 6 (current limit: 5)")) 105 }) 106 107 It("blocks AcceptStream until a new stream is available", func() { 108 strChan := make(chan *mockGenericStream) 109 go func() { 110 defer GinkgoRecover() 111 str, err := m.AcceptStream(context.Background()) 112 Expect(err).ToNot(HaveOccurred()) 113 strChan <- str 114 }() 115 Consistently(strChan).ShouldNot(Receive()) 116 str, err := m.GetOrOpenStream(1) 117 Expect(err).ToNot(HaveOccurred()) 118 Expect(str.num).To(Equal(protocol.StreamNum(1))) 119 var acceptedStr *mockGenericStream 120 Eventually(strChan).Should(Receive(&acceptedStr)) 121 Expect(acceptedStr.num).To(Equal(protocol.StreamNum(1))) 122 }) 123 124 It("unblocks AcceptStream when the context is canceled", func() { 125 ctx, cancel := context.WithCancel(context.Background()) 126 done := make(chan struct{}) 127 go func() { 128 defer GinkgoRecover() 129 _, err := m.AcceptStream(ctx) 130 Expect(err).To(MatchError("context canceled")) 131 close(done) 132 }() 133 Consistently(done).ShouldNot(BeClosed()) 134 cancel() 135 Eventually(done).Should(BeClosed()) 136 }) 137 138 It("unblocks AcceptStream when it is closed", func() { 139 testErr := errors.New("test error") 140 done := make(chan struct{}) 141 go func() { 142 defer GinkgoRecover() 143 _, err := m.AcceptStream(context.Background()) 144 Expect(err).To(MatchError(testErr)) 145 close(done) 146 }() 147 Consistently(done).ShouldNot(BeClosed()) 148 m.CloseWithError(testErr) 149 Eventually(done).Should(BeClosed()) 150 }) 151 152 It("errors AcceptStream immediately if it is closed", func() { 153 testErr := errors.New("test error") 154 m.CloseWithError(testErr) 155 _, err := m.AcceptStream(context.Background()) 156 Expect(err).To(MatchError(testErr)) 157 }) 158 159 It("closes all streams when CloseWithError is called", func() { 160 str1, err := m.GetOrOpenStream(1) 161 Expect(err).ToNot(HaveOccurred()) 162 str2, err := m.GetOrOpenStream(3) 163 Expect(err).ToNot(HaveOccurred()) 164 testErr := errors.New("test err") 165 m.CloseWithError(testErr) 166 Expect(str1.closed).To(BeTrue()) 167 Expect(str1.closeErr).To(MatchError(testErr)) 168 Expect(str2.closed).To(BeTrue()) 169 Expect(str2.closeErr).To(MatchError(testErr)) 170 }) 171 172 It("deletes streams", func() { 173 mockSender.EXPECT().queueControlFrame(gomock.Any()) 174 _, err := m.GetOrOpenStream(1) 175 Expect(err).ToNot(HaveOccurred()) 176 str, err := m.AcceptStream(context.Background()) 177 Expect(err).ToNot(HaveOccurred()) 178 Expect(str.num).To(Equal(protocol.StreamNum(1))) 179 Expect(m.DeleteStream(1)).To(Succeed()) 180 str, err = m.GetOrOpenStream(1) 181 Expect(err).ToNot(HaveOccurred()) 182 Expect(str).To(BeNil()) 183 }) 184 185 It("waits until a stream is accepted before actually deleting it", func() { 186 _, err := m.GetOrOpenStream(2) 187 Expect(err).ToNot(HaveOccurred()) 188 Expect(m.DeleteStream(2)).To(Succeed()) 189 str, err := m.AcceptStream(context.Background()) 190 Expect(err).ToNot(HaveOccurred()) 191 Expect(str.num).To(Equal(protocol.StreamNum(1))) 192 // when accepting this stream, it will get deleted, and a MAX_STREAMS frame is queued 193 mockSender.EXPECT().queueControlFrame(gomock.Any()) 194 str, err = m.AcceptStream(context.Background()) 195 Expect(err).ToNot(HaveOccurred()) 196 Expect(str.num).To(Equal(protocol.StreamNum(2))) 197 }) 198 199 It("doesn't return a stream queued for deleting from GetOrOpenStream", func() { 200 str, err := m.GetOrOpenStream(1) 201 Expect(err).ToNot(HaveOccurred()) 202 Expect(str).ToNot(BeNil()) 203 Expect(m.DeleteStream(1)).To(Succeed()) 204 str, err = m.GetOrOpenStream(1) 205 Expect(err).ToNot(HaveOccurred()) 206 Expect(str).To(BeNil()) 207 // when accepting this stream, it will get deleted, and a MAX_STREAMS frame is queued 208 mockSender.EXPECT().queueControlFrame(gomock.Any()) 209 str, err = m.AcceptStream(context.Background()) 210 Expect(err).ToNot(HaveOccurred()) 211 Expect(str).ToNot(BeNil()) 212 }) 213 214 It("errors when deleting a non-existing stream", func() { 215 err := m.DeleteStream(1337) 216 Expect(err).To(HaveOccurred()) 217 Expect(err.(streamError).TestError()).To(MatchError("tried to delete unknown incoming stream 1337")) 218 }) 219 220 It("sends MAX_STREAMS frames when streams are deleted", func() { 221 // open a bunch of streams 222 _, err := m.GetOrOpenStream(5) 223 Expect(err).ToNot(HaveOccurred()) 224 // accept all streams 225 for i := 0; i < 5; i++ { 226 _, err := m.AcceptStream(context.Background()) 227 Expect(err).ToNot(HaveOccurred()) 228 } 229 mockSender.EXPECT().queueControlFrame(gomock.Any()).Do(func(f wire.Frame) { 230 msf := f.(*wire.MaxStreamsFrame) 231 Expect(msf.Type).To(BeEquivalentTo(streamType)) 232 Expect(msf.MaxStreamNum).To(Equal(protocol.StreamNum(maxNumStreams + 1))) 233 checkFrameSerialization(f) 234 }) 235 Expect(m.DeleteStream(3)).To(Succeed()) 236 mockSender.EXPECT().queueControlFrame(gomock.Any()).Do(func(f wire.Frame) { 237 Expect(f.(*wire.MaxStreamsFrame).MaxStreamNum).To(Equal(protocol.StreamNum(maxNumStreams + 2))) 238 checkFrameSerialization(f) 239 }) 240 Expect(m.DeleteStream(4)).To(Succeed()) 241 }) 242 243 Context("using high stream limits", func() { 244 BeforeEach(func() { maxNumStreams = uint64(protocol.MaxStreamCount) - 2 }) 245 246 It("doesn't send MAX_STREAMS frames if they would overflow 2^60 (the maximum stream count)", func() { 247 // open a bunch of streams 248 _, err := m.GetOrOpenStream(5) 249 Expect(err).ToNot(HaveOccurred()) 250 // accept all streams 251 for i := 0; i < 5; i++ { 252 _, err := m.AcceptStream(context.Background()) 253 Expect(err).ToNot(HaveOccurred()) 254 } 255 mockSender.EXPECT().queueControlFrame(gomock.Any()).Do(func(f wire.Frame) { 256 Expect(f.(*wire.MaxStreamsFrame).MaxStreamNum).To(Equal(protocol.MaxStreamCount - 1)) 257 checkFrameSerialization(f) 258 }) 259 Expect(m.DeleteStream(4)).To(Succeed()) 260 mockSender.EXPECT().queueControlFrame(gomock.Any()).Do(func(f wire.Frame) { 261 Expect(f.(*wire.MaxStreamsFrame).MaxStreamNum).To(Equal(protocol.MaxStreamCount)) 262 checkFrameSerialization(f) 263 }) 264 Expect(m.DeleteStream(3)).To(Succeed()) 265 // at this point, we can't increase the stream limit any further, so no more MAX_STREAMS frames will be sent 266 Expect(m.DeleteStream(2)).To(Succeed()) 267 Expect(m.DeleteStream(1)).To(Succeed()) 268 }) 269 }) 270 271 Context("randomized tests", func() { 272 const num = 1000 273 274 BeforeEach(func() { maxNumStreams = num }) 275 276 It("opens and accepts streams", func() { 277 rand.Seed(GinkgoRandomSeed()) 278 ids := make([]protocol.StreamNum, num) 279 for i := 0; i < num; i++ { 280 ids[i] = protocol.StreamNum(i + 1) 281 } 282 rand.Shuffle(len(ids), func(i, j int) { ids[i], ids[j] = ids[j], ids[i] }) 283 284 const timeout = 5 * time.Second 285 done := make(chan struct{}, 2) 286 go func() { 287 defer GinkgoRecover() 288 ctx, cancel := context.WithTimeout(context.Background(), timeout) 289 defer cancel() 290 for i := 0; i < num; i++ { 291 _, err := m.AcceptStream(ctx) 292 Expect(err).ToNot(HaveOccurred()) 293 } 294 done <- struct{}{} 295 }() 296 297 go func() { 298 defer GinkgoRecover() 299 for i := 0; i < num; i++ { 300 _, err := m.GetOrOpenStream(ids[i]) 301 Expect(err).ToNot(HaveOccurred()) 302 } 303 done <- struct{}{} 304 }() 305 306 Eventually(done, timeout*3/2).Should(Receive()) 307 Eventually(done, timeout*3/2).Should(Receive()) 308 }) 309 }) 310 })