github.com/metacubex/quic-go@v0.44.1-0.20240520163451-20b689a59136/http3/state_tracking_stream_test.go (about) 1 package http3 2 3 import ( 4 "bytes" 5 "context" 6 "errors" 7 "io" 8 "os" 9 10 "github.com/metacubex/quic-go" 11 mockquic "github.com/metacubex/quic-go/internal/mocks/quic" 12 13 . "github.com/onsi/ginkgo/v2" 14 . "github.com/onsi/gomega" 15 "go.uber.org/mock/gomock" 16 ) 17 18 var someStreamID = quic.StreamID(12) 19 20 var _ = Describe("State Tracking Stream", func() { 21 It("recognizes when the receive side is closed", func() { 22 qstr := mockquic.NewMockStream(mockCtrl) 23 qstr.EXPECT().StreamID().AnyTimes().Return(someStreamID) 24 qstr.EXPECT().Context().Return(context.Background()).AnyTimes() 25 26 var ( 27 clearer mockStreamClearer 28 setter mockErrorSetter 29 str = newStateTrackingStream(qstr, &clearer, &setter) 30 ) 31 32 buf := bytes.NewBuffer([]byte("foobar")) 33 qstr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes() 34 for i := 0; i < 3; i++ { 35 _, err := str.Read([]byte{0}) 36 Expect(err).ToNot(HaveOccurred()) 37 Expect(clearer.cleared).To(BeNil()) 38 Expect(setter.recvErrs).To(BeEmpty()) 39 Expect(setter.sendErrs).To(BeEmpty()) 40 } 41 _, err := io.ReadAll(str) 42 Expect(err).ToNot(HaveOccurred()) 43 Expect(clearer.cleared).To(BeNil()) 44 Expect(setter.recvErrs).To(HaveLen(1)) 45 Expect(setter.recvErrs[0]).To(Equal(io.EOF)) 46 Expect(setter.sendErrs).To(BeEmpty()) 47 }) 48 49 It("recognizes local read cancellations", func() { 50 qstr := mockquic.NewMockStream(mockCtrl) 51 qstr.EXPECT().StreamID().AnyTimes().Return(someStreamID) 52 qstr.EXPECT().Context().Return(context.Background()).AnyTimes() 53 54 var ( 55 clearer mockStreamClearer 56 setter mockErrorSetter 57 str = newStateTrackingStream(qstr, &clearer, &setter) 58 ) 59 60 buf := bytes.NewBuffer([]byte("foobar")) 61 qstr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes() 62 qstr.EXPECT().CancelRead(quic.StreamErrorCode(1337)) 63 _, err := str.Read(make([]byte, 3)) 64 Expect(err).ToNot(HaveOccurred()) 65 Expect(clearer.cleared).To(BeNil()) 66 Expect(setter.recvErrs).To(BeEmpty()) 67 Expect(setter.sendErrs).To(BeEmpty()) 68 69 str.CancelRead(1337) 70 Expect(clearer.cleared).To(BeNil()) 71 Expect(setter.recvErrs).To(HaveLen(1)) 72 Expect(setter.recvErrs[0]).To(Equal(&quic.StreamError{StreamID: someStreamID, ErrorCode: 1337})) 73 Expect(setter.sendErrs).To(BeEmpty()) 74 }) 75 76 It("recognizes remote cancellations", func() { 77 qstr := mockquic.NewMockStream(mockCtrl) 78 qstr.EXPECT().StreamID().AnyTimes().Return(someStreamID) 79 qstr.EXPECT().Context().Return(context.Background()).AnyTimes() 80 81 var ( 82 clearer mockStreamClearer 83 setter mockErrorSetter 84 str = newStateTrackingStream(qstr, &clearer, &setter) 85 ) 86 87 testErr := errors.New("test error") 88 qstr.EXPECT().Read(gomock.Any()).Return(0, testErr) 89 _, err := str.Read(make([]byte, 3)) 90 Expect(err).To(MatchError(testErr)) 91 Expect(clearer.cleared).To(BeNil()) 92 Expect(setter.recvErrs).To(HaveLen(1)) 93 Expect(setter.recvErrs[0]).To(Equal(testErr)) 94 Expect(setter.sendErrs).To(BeEmpty()) 95 }) 96 97 It("doesn't misinterpret read deadline errors", func() { 98 qstr := mockquic.NewMockStream(mockCtrl) 99 qstr.EXPECT().StreamID().AnyTimes().Return(someStreamID) 100 qstr.EXPECT().Context().Return(context.Background()).AnyTimes() 101 102 var ( 103 clearer mockStreamClearer 104 setter mockErrorSetter 105 str = newStateTrackingStream(qstr, &clearer, &setter) 106 ) 107 108 qstr.EXPECT().Read(gomock.Any()).Return(0, os.ErrDeadlineExceeded) 109 _, err := str.Read(make([]byte, 3)) 110 Expect(err).To(MatchError(os.ErrDeadlineExceeded)) 111 Expect(clearer.cleared).To(BeNil()) 112 Expect(setter.recvErrs).To(BeEmpty()) 113 Expect(setter.sendErrs).To(BeEmpty()) 114 }) 115 116 It("recognizes when the send side is closed, when write errors", func() { 117 qstr := mockquic.NewMockStream(mockCtrl) 118 qstr.EXPECT().StreamID().AnyTimes().Return(someStreamID) 119 qstr.EXPECT().Context().Return(context.Background()).AnyTimes() 120 121 var ( 122 clearer mockStreamClearer 123 setter mockErrorSetter 124 str = newStateTrackingStream(qstr, &clearer, &setter) 125 ) 126 127 testErr := errors.New("test error") 128 qstr.EXPECT().Write([]byte("foo")).Return(3, nil) 129 qstr.EXPECT().Write([]byte("bar")).Return(0, testErr) 130 131 _, err := str.Write([]byte("foo")) 132 Expect(err).ToNot(HaveOccurred()) 133 Expect(clearer.cleared).To(BeNil()) 134 Expect(setter.recvErrs).To(BeEmpty()) 135 Expect(setter.sendErrs).To(BeEmpty()) 136 137 _, err = str.Write([]byte("bar")) 138 Expect(err).To(MatchError(testErr)) 139 Expect(clearer.cleared).To(BeNil()) 140 Expect(setter.recvErrs).To(BeEmpty()) 141 Expect(setter.sendErrs).To(HaveLen(1)) 142 Expect(setter.sendErrs[0]).To(Equal(testErr)) 143 }) 144 145 It("recognizes when the send side is closed, when write errors", func() { 146 qstr := mockquic.NewMockStream(mockCtrl) 147 qstr.EXPECT().StreamID().AnyTimes().Return(someStreamID) 148 qstr.EXPECT().Context().Return(context.Background()).AnyTimes() 149 150 var ( 151 clearer mockStreamClearer 152 setter mockErrorSetter 153 str = newStateTrackingStream(qstr, &clearer, &setter) 154 ) 155 156 qstr.EXPECT().Write([]byte("foo")).Return(0, os.ErrDeadlineExceeded) 157 Expect(clearer.cleared).To(BeNil()) 158 Expect(setter.recvErrs).To(BeEmpty()) 159 Expect(setter.sendErrs).To(BeEmpty()) 160 161 _, err := str.Write([]byte("foo")) 162 Expect(err).To(MatchError(os.ErrDeadlineExceeded)) 163 Expect(clearer.cleared).To(BeNil()) 164 Expect(setter.recvErrs).To(BeEmpty()) 165 Expect(setter.sendErrs).To(BeEmpty()) 166 }) 167 168 It("recognizes when the send side is closed, when CancelWrite is called", func() { 169 qstr := mockquic.NewMockStream(mockCtrl) 170 qstr.EXPECT().StreamID().AnyTimes().Return(someStreamID) 171 qstr.EXPECT().Context().Return(context.Background()).AnyTimes() 172 173 var ( 174 clearer mockStreamClearer 175 setter mockErrorSetter 176 str = newStateTrackingStream(qstr, &clearer, &setter) 177 ) 178 179 qstr.EXPECT().Write(gomock.Any()) 180 qstr.EXPECT().CancelWrite(quic.StreamErrorCode(1337)) 181 _, err := str.Write([]byte("foobar")) 182 Expect(err).ToNot(HaveOccurred()) 183 Expect(clearer.cleared).To(BeNil()) 184 Expect(setter.recvErrs).To(BeEmpty()) 185 Expect(setter.sendErrs).To(BeEmpty()) 186 187 str.CancelWrite(1337) 188 Expect(clearer.cleared).To(BeNil()) 189 Expect(setter.recvErrs).To(BeEmpty()) 190 Expect(setter.sendErrs).To(HaveLen(1)) 191 Expect(setter.sendErrs[0]).To(Equal(&quic.StreamError{StreamID: someStreamID, ErrorCode: 1337})) 192 }) 193 194 It("recognizes when the send side is closed, when the stream context is canceled", func() { 195 qstr := mockquic.NewMockStream(mockCtrl) 196 qstr.EXPECT().StreamID().AnyTimes() 197 ctx, cancel := context.WithCancelCause(context.Background()) 198 qstr.EXPECT().Context().Return(ctx).AnyTimes() 199 200 var ( 201 clearer mockStreamClearer 202 setter = mockErrorSetter{ 203 sendSent: make(chan struct{}), 204 } 205 ) 206 207 _ = newStateTrackingStream(qstr, &clearer, &setter) 208 Expect(clearer.cleared).To(BeNil()) 209 Expect(setter.recvErrs).To(BeEmpty()) 210 Expect(setter.sendErrs).To(BeEmpty()) 211 212 testErr := errors.New("test error") 213 cancel(testErr) 214 Eventually(setter.sendSent).Should(BeClosed()) 215 Expect(clearer.cleared).To(BeNil()) 216 Expect(setter.recvErrs).To(BeEmpty()) 217 Expect(setter.sendErrs).To(HaveLen(1)) 218 Expect(setter.sendErrs[0]).To(Equal(testErr)) 219 }) 220 221 It("clears the stream when receive is closed followed by send is closed", func() { 222 qstr := mockquic.NewMockStream(mockCtrl) 223 qstr.EXPECT().StreamID().AnyTimes().Return(someStreamID) 224 qstr.EXPECT().Context().Return(context.Background()).AnyTimes() 225 226 var ( 227 clearer mockStreamClearer 228 setter mockErrorSetter 229 str = newStateTrackingStream(qstr, &clearer, &setter) 230 ) 231 232 buf := bytes.NewBuffer([]byte("foobar")) 233 qstr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes() 234 _, err := io.ReadAll(str) 235 Expect(err).ToNot(HaveOccurred()) 236 237 Expect(clearer.cleared).To(BeNil()) 238 Expect(setter.recvErrs).To(HaveLen(1)) 239 Expect(setter.recvErrs[0]).To(Equal(io.EOF)) 240 241 testErr := errors.New("test error") 242 qstr.EXPECT().Write([]byte("bar")).Return(0, testErr) 243 244 _, err = str.Write([]byte("bar")) 245 Expect(err).To(MatchError(testErr)) 246 Expect(setter.sendErrs).To(HaveLen(1)) 247 Expect(setter.sendErrs[0]).To(Equal(testErr)) 248 249 Expect(clearer.cleared).To(Equal(&someStreamID)) 250 }) 251 252 It("clears the stream when send is closed followed by receive is closed", func() { 253 qstr := mockquic.NewMockStream(mockCtrl) 254 qstr.EXPECT().StreamID().AnyTimes().Return(someStreamID) 255 qstr.EXPECT().Context().Return(context.Background()).AnyTimes() 256 257 var ( 258 clearer mockStreamClearer 259 setter mockErrorSetter 260 str = newStateTrackingStream(qstr, &clearer, &setter) 261 ) 262 263 testErr := errors.New("test error") 264 qstr.EXPECT().Write([]byte("bar")).Return(0, testErr) 265 266 _, err := str.Write([]byte("bar")) 267 Expect(err).To(MatchError(testErr)) 268 Expect(clearer.cleared).To(BeNil()) 269 Expect(setter.sendErrs).To(HaveLen(1)) 270 Expect(setter.sendErrs[0]).To(Equal(testErr)) 271 272 buf := bytes.NewBuffer([]byte("foobar")) 273 qstr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes() 274 275 _, err = io.ReadAll(str) 276 Expect(err).ToNot(HaveOccurred()) 277 Expect(setter.recvErrs).To(HaveLen(1)) 278 Expect(setter.recvErrs[0]).To(Equal(io.EOF)) 279 280 Expect(clearer.cleared).To(Equal(&someStreamID)) 281 }) 282 }) 283 284 type mockStreamClearer struct { 285 cleared *quic.StreamID 286 } 287 288 func (s *mockStreamClearer) clearStream(id quic.StreamID) { 289 s.cleared = &id 290 } 291 292 type mockErrorSetter struct { 293 sendErrs []error 294 recvErrs []error 295 296 sendSent chan struct{} 297 } 298 299 func (e *mockErrorSetter) SetSendError(err error) { 300 e.sendErrs = append(e.sendErrs, err) 301 302 if e.sendSent != nil { 303 close(e.sendSent) 304 } 305 } 306 307 func (e *mockErrorSetter) SetReceiveError(err error) { 308 e.recvErrs = append(e.recvErrs, err) 309 }