github.com/quic-go/quic-go@v0.44.0/http3/state_tracking_stream.go (about) 1 package http3 2 3 import ( 4 "context" 5 "errors" 6 "os" 7 "sync" 8 9 "github.com/quic-go/quic-go" 10 ) 11 12 var _ quic.Stream = &stateTrackingStream{} 13 14 // stateTrackingStream is an implementation of quic.Stream that delegates 15 // to an underlying stream 16 // it takes care of proxying send and receive errors onto an implementation of 17 // the errorSetter interface (intended to be occupied by a datagrammer) 18 // it is also responsible for clearing the stream based on its ID from its 19 // parent connection, this is done through the streamClearer interface when 20 // both the send and receive sides are closed 21 type stateTrackingStream struct { 22 quic.Stream 23 24 mx sync.Mutex 25 sendErr error 26 recvErr error 27 28 clearer streamClearer 29 setter errorSetter 30 } 31 32 type streamClearer interface { 33 clearStream(quic.StreamID) 34 } 35 36 type errorSetter interface { 37 SetSendError(error) 38 SetReceiveError(error) 39 } 40 41 func newStateTrackingStream(s quic.Stream, clearer streamClearer, setter errorSetter) *stateTrackingStream { 42 t := &stateTrackingStream{ 43 Stream: s, 44 clearer: clearer, 45 setter: setter, 46 } 47 48 context.AfterFunc(s.Context(), func() { 49 t.closeSend(context.Cause(s.Context())) 50 }) 51 52 return t 53 } 54 55 func (s *stateTrackingStream) closeSend(e error) { 56 s.mx.Lock() 57 defer s.mx.Unlock() 58 59 // clear the stream the first time both the send 60 // and receive are finished 61 if s.sendErr == nil { 62 if s.recvErr != nil { 63 s.clearer.clearStream(s.StreamID()) 64 } 65 66 s.setter.SetSendError(e) 67 s.sendErr = e 68 } 69 } 70 71 func (s *stateTrackingStream) closeReceive(e error) { 72 s.mx.Lock() 73 defer s.mx.Unlock() 74 75 // clear the stream the first time both the send 76 // and receive are finished 77 if s.recvErr == nil { 78 if s.sendErr != nil { 79 s.clearer.clearStream(s.StreamID()) 80 } 81 82 s.setter.SetReceiveError(e) 83 s.recvErr = e 84 } 85 } 86 87 func (s *stateTrackingStream) Close() error { 88 s.closeSend(errors.New("write on closed stream")) 89 return s.Stream.Close() 90 } 91 92 func (s *stateTrackingStream) CancelWrite(e quic.StreamErrorCode) { 93 s.closeSend(&quic.StreamError{StreamID: s.Stream.StreamID(), ErrorCode: e}) 94 s.Stream.CancelWrite(e) 95 } 96 97 func (s *stateTrackingStream) Write(b []byte) (int, error) { 98 n, err := s.Stream.Write(b) 99 if err != nil && !errors.Is(err, os.ErrDeadlineExceeded) { 100 s.closeSend(err) 101 } 102 return n, err 103 } 104 105 func (s *stateTrackingStream) CancelRead(e quic.StreamErrorCode) { 106 s.closeReceive(&quic.StreamError{StreamID: s.Stream.StreamID(), ErrorCode: e}) 107 s.Stream.CancelRead(e) 108 } 109 110 func (s *stateTrackingStream) Read(b []byte) (int, error) { 111 n, err := s.Stream.Read(b) 112 if err != nil && !errors.Is(err, os.ErrDeadlineExceeded) { 113 s.closeReceive(err) 114 } 115 return n, err 116 }