github.com/sagernet/quic-go@v0.43.1-beta.1/http3_ech/state_tracking_stream.go (about) 1 package http3 2 3 import ( 4 "errors" 5 "sync" 6 7 "github.com/sagernet/quic-go/ech" 8 ) 9 10 type streamState uint8 11 12 const ( 13 streamStateOpen streamState = iota 14 streamStateReceiveClosed 15 streamStateSendClosed 16 streamStateSendAndReceiveClosed 17 ) 18 19 type stateTrackingStream struct { 20 quic.Stream 21 22 mx sync.Mutex 23 state streamState 24 25 onStateChange func(streamState, error) 26 } 27 28 func newStateTrackingStream(s quic.Stream, onStateChange func(streamState, error)) *stateTrackingStream { 29 return &stateTrackingStream{ 30 Stream: s, 31 state: streamStateOpen, 32 onStateChange: onStateChange, 33 } 34 } 35 36 var _ quic.Stream = &stateTrackingStream{} 37 38 func (s *stateTrackingStream) closeSend(e error) { 39 s.mx.Lock() 40 defer s.mx.Unlock() 41 42 if s.state == streamStateReceiveClosed || s.state == streamStateSendAndReceiveClosed { 43 s.state = streamStateSendAndReceiveClosed 44 } else { 45 s.state = streamStateSendClosed 46 } 47 s.onStateChange(s.state, e) 48 } 49 50 func (s *stateTrackingStream) closeReceive(e error) { 51 s.mx.Lock() 52 defer s.mx.Unlock() 53 54 if s.state == streamStateSendClosed || s.state == streamStateSendAndReceiveClosed { 55 s.state = streamStateSendAndReceiveClosed 56 } else { 57 s.state = streamStateReceiveClosed 58 } 59 s.onStateChange(s.state, e) 60 } 61 62 func (s *stateTrackingStream) Close() error { 63 s.closeSend(errors.New("write on closed stream")) 64 return s.Stream.Close() 65 } 66 67 func (s *stateTrackingStream) CancelWrite(e quic.StreamErrorCode) { 68 s.closeSend(&quic.StreamError{StreamID: s.Stream.StreamID(), ErrorCode: e}) 69 s.Stream.CancelWrite(e) 70 } 71 72 func (s *stateTrackingStream) Write(b []byte) (int, error) { 73 n, err := s.Stream.Write(b) 74 if err != nil { 75 s.closeSend(err) 76 } 77 return n, err 78 } 79 80 func (s *stateTrackingStream) CancelRead(e quic.StreamErrorCode) { 81 s.closeReceive(&quic.StreamError{StreamID: s.Stream.StreamID(), ErrorCode: e}) 82 s.Stream.CancelRead(e) 83 } 84 85 func (s *stateTrackingStream) Read(b []byte) (int, error) { 86 n, err := s.Stream.Read(b) 87 if err != nil { 88 s.closeReceive(err) 89 } 90 return n, err 91 }