github.com/metacubex/quic-go@v0.44.1-0.20240520163451-20b689a59136/http3/state_tracking_stream.go (about) 1 package http3 2 3 import ( 4 "context" 5 "errors" 6 "os" 7 "sync" 8 9 "github.com/metacubex/quic-go" 10 "github.com/metacubex/quic-go/internal/utils" 11 ) 12 13 var _ quic.Stream = &stateTrackingStream{} 14 15 // stateTrackingStream is an implementation of quic.Stream that delegates 16 // to an underlying stream 17 // it takes care of proxying send and receive errors onto an implementation of 18 // the errorSetter interface (intended to be occupied by a datagrammer) 19 // it is also responsible for clearing the stream based on its ID from its 20 // parent connection, this is done through the streamClearer interface when 21 // both the send and receive sides are closed 22 type stateTrackingStream struct { 23 quic.Stream 24 25 mx sync.Mutex 26 sendErr error 27 recvErr error 28 29 clearer streamClearer 30 setter errorSetter 31 } 32 33 type streamClearer interface { 34 clearStream(quic.StreamID) 35 } 36 37 type errorSetter interface { 38 SetSendError(error) 39 SetReceiveError(error) 40 } 41 42 func newStateTrackingStream(s quic.Stream, clearer streamClearer, setter errorSetter) *stateTrackingStream { 43 t := &stateTrackingStream{ 44 Stream: s, 45 clearer: clearer, 46 setter: setter, 47 } 48 49 utils.AfterFunc(s.Context(), func() { 50 t.closeSend(context.Cause(s.Context())) 51 }) 52 53 return t 54 } 55 56 func (s *stateTrackingStream) closeSend(e error) { 57 s.mx.Lock() 58 defer s.mx.Unlock() 59 60 // clear the stream the first time both the send 61 // and receive are finished 62 if s.sendErr == nil { 63 if s.recvErr != nil { 64 s.clearer.clearStream(s.StreamID()) 65 } 66 67 s.setter.SetSendError(e) 68 s.sendErr = e 69 } 70 } 71 72 func (s *stateTrackingStream) closeReceive(e error) { 73 s.mx.Lock() 74 defer s.mx.Unlock() 75 76 // clear the stream the first time both the send 77 // and receive are finished 78 if s.recvErr == nil { 79 if s.sendErr != nil { 80 s.clearer.clearStream(s.StreamID()) 81 } 82 83 s.setter.SetReceiveError(e) 84 s.recvErr = e 85 } 86 } 87 88 func (s *stateTrackingStream) Close() error { 89 s.closeSend(errors.New("write on closed stream")) 90 return s.Stream.Close() 91 } 92 93 func (s *stateTrackingStream) CancelWrite(e quic.StreamErrorCode) { 94 s.closeSend(&quic.StreamError{StreamID: s.Stream.StreamID(), ErrorCode: e}) 95 s.Stream.CancelWrite(e) 96 } 97 98 func (s *stateTrackingStream) Write(b []byte) (int, error) { 99 n, err := s.Stream.Write(b) 100 if err != nil && !errors.Is(err, os.ErrDeadlineExceeded) { 101 s.closeSend(err) 102 } 103 return n, err 104 } 105 106 func (s *stateTrackingStream) CancelRead(e quic.StreamErrorCode) { 107 s.closeReceive(&quic.StreamError{StreamID: s.Stream.StreamID(), ErrorCode: e}) 108 s.Stream.CancelRead(e) 109 } 110 111 func (s *stateTrackingStream) Read(b []byte) (int, error) { 112 n, err := s.Stream.Read(b) 113 if err != nil && !errors.Is(err, os.ErrDeadlineExceeded) { 114 s.closeReceive(err) 115 } 116 return n, err 117 }