github.com/apernet/quic-go@v0.43.1-0.20240515053213-5e9e635fd9f0/http3/state_tracking_stream.go (about)

     1  package http3
     2  
import (
	"errors"
	"os"
	"sync"

	"github.com/apernet/quic-go"
)
     9  
    10  type streamState uint8
    11  
    12  const (
    13  	streamStateOpen streamState = iota
    14  	streamStateReceiveClosed
    15  	streamStateSendClosed
    16  	streamStateSendAndReceiveClosed
    17  )
    18  
// stateTrackingStream wraps a quic.Stream and tracks whether its send and
// receive directions have been closed. Close and CancelWrite mark the send
// direction closed, CancelRead marks the receive direction closed, and errors
// surfaced by Write or Read may do the same. Each such transition is reported
// to onStateChange together with the error that caused it.
type stateTrackingStream struct {
	quic.Stream

	// mx guards state; closeSend/closeReceive hold it while calling
	// onStateChange.
	mx    sync.Mutex
	state streamState

	// onStateChange receives the new state and the error that triggered the
	// transition. It is called unconditionally, so it must be non-nil.
	onStateChange func(streamState, error)
}
    27  
    28  func newStateTrackingStream(s quic.Stream, onStateChange func(streamState, error)) *stateTrackingStream {
    29  	return &stateTrackingStream{
    30  		Stream:        s,
    31  		state:         streamStateOpen,
    32  		onStateChange: onStateChange,
    33  	}
    34  }
    35  
    36  var _ quic.Stream = &stateTrackingStream{}
    37  
    38  func (s *stateTrackingStream) closeSend(e error) {
    39  	s.mx.Lock()
    40  	defer s.mx.Unlock()
    41  
    42  	if s.state == streamStateReceiveClosed || s.state == streamStateSendAndReceiveClosed {
    43  		s.state = streamStateSendAndReceiveClosed
    44  	} else {
    45  		s.state = streamStateSendClosed
    46  	}
    47  	s.onStateChange(s.state, e)
    48  }
    49  
    50  func (s *stateTrackingStream) closeReceive(e error) {
    51  	s.mx.Lock()
    52  	defer s.mx.Unlock()
    53  
    54  	if s.state == streamStateSendClosed || s.state == streamStateSendAndReceiveClosed {
    55  		s.state = streamStateSendAndReceiveClosed
    56  	} else {
    57  		s.state = streamStateReceiveClosed
    58  	}
    59  	s.onStateChange(s.state, e)
    60  }
    61  
    62  func (s *stateTrackingStream) Close() error {
    63  	s.closeSend(errors.New("write on closed stream"))
    64  	return s.Stream.Close()
    65  }
    66  
    67  func (s *stateTrackingStream) CancelWrite(e quic.StreamErrorCode) {
    68  	s.closeSend(&quic.StreamError{StreamID: s.Stream.StreamID(), ErrorCode: e})
    69  	s.Stream.CancelWrite(e)
    70  }
    71  
    72  func (s *stateTrackingStream) Write(b []byte) (int, error) {
    73  	n, err := s.Stream.Write(b)
    74  	if err != nil {
    75  		s.closeSend(err)
    76  	}
    77  	return n, err
    78  }
    79  
    80  func (s *stateTrackingStream) CancelRead(e quic.StreamErrorCode) {
    81  	s.closeReceive(&quic.StreamError{StreamID: s.Stream.StreamID(), ErrorCode: e})
    82  	s.Stream.CancelRead(e)
    83  }
    84  
    85  func (s *stateTrackingStream) Read(b []byte) (int, error) {
    86  	n, err := s.Stream.Read(b)
    87  	if err != nil {
    88  		s.closeReceive(err)
    89  	}
    90  	return n, err
    91  }