github.com/quic-go/quic-go@v0.44.0/http3/state_tracking_stream.go (about)

     1  package http3
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  	"os"
     7  	"sync"
     8  
     9  	"github.com/quic-go/quic-go"
    10  )
    11  
    12  var _ quic.Stream = &stateTrackingStream{}
    13  
    14  // stateTrackingStream is an implementation of quic.Stream that delegates
    15  // to an underlying stream
    16  // it takes care of proxying send and receive errors onto an implementation of
    17  // the errorSetter interface (intended to be occupied by a datagrammer)
    18  // it is also responsible for clearing the stream based on its ID from its
    19  // parent connection, this is done through the streamClearer interface when
    20  // both the send and receive sides are closed
    21  type stateTrackingStream struct {
    22  	quic.Stream
    23  
    24  	mx      sync.Mutex
    25  	sendErr error
    26  	recvErr error
    27  
    28  	clearer streamClearer
    29  	setter  errorSetter
    30  }
    31  
    32  type streamClearer interface {
    33  	clearStream(quic.StreamID)
    34  }
    35  
    36  type errorSetter interface {
    37  	SetSendError(error)
    38  	SetReceiveError(error)
    39  }
    40  
    41  func newStateTrackingStream(s quic.Stream, clearer streamClearer, setter errorSetter) *stateTrackingStream {
    42  	t := &stateTrackingStream{
    43  		Stream:  s,
    44  		clearer: clearer,
    45  		setter:  setter,
    46  	}
    47  
    48  	context.AfterFunc(s.Context(), func() {
    49  		t.closeSend(context.Cause(s.Context()))
    50  	})
    51  
    52  	return t
    53  }
    54  
    55  func (s *stateTrackingStream) closeSend(e error) {
    56  	s.mx.Lock()
    57  	defer s.mx.Unlock()
    58  
    59  	// clear the stream the first time both the send
    60  	// and receive are finished
    61  	if s.sendErr == nil {
    62  		if s.recvErr != nil {
    63  			s.clearer.clearStream(s.StreamID())
    64  		}
    65  
    66  		s.setter.SetSendError(e)
    67  		s.sendErr = e
    68  	}
    69  }
    70  
    71  func (s *stateTrackingStream) closeReceive(e error) {
    72  	s.mx.Lock()
    73  	defer s.mx.Unlock()
    74  
    75  	// clear the stream the first time both the send
    76  	// and receive are finished
    77  	if s.recvErr == nil {
    78  		if s.sendErr != nil {
    79  			s.clearer.clearStream(s.StreamID())
    80  		}
    81  
    82  		s.setter.SetReceiveError(e)
    83  		s.recvErr = e
    84  	}
    85  }
    86  
    87  func (s *stateTrackingStream) Close() error {
    88  	s.closeSend(errors.New("write on closed stream"))
    89  	return s.Stream.Close()
    90  }
    91  
    92  func (s *stateTrackingStream) CancelWrite(e quic.StreamErrorCode) {
    93  	s.closeSend(&quic.StreamError{StreamID: s.Stream.StreamID(), ErrorCode: e})
    94  	s.Stream.CancelWrite(e)
    95  }
    96  
    97  func (s *stateTrackingStream) Write(b []byte) (int, error) {
    98  	n, err := s.Stream.Write(b)
    99  	if err != nil && !errors.Is(err, os.ErrDeadlineExceeded) {
   100  		s.closeSend(err)
   101  	}
   102  	return n, err
   103  }
   104  
   105  func (s *stateTrackingStream) CancelRead(e quic.StreamErrorCode) {
   106  	s.closeReceive(&quic.StreamError{StreamID: s.Stream.StreamID(), ErrorCode: e})
   107  	s.Stream.CancelRead(e)
   108  }
   109  
   110  func (s *stateTrackingStream) Read(b []byte) (int, error) {
   111  	n, err := s.Stream.Read(b)
   112  	if err != nil && !errors.Is(err, os.ErrDeadlineExceeded) {
   113  		s.closeReceive(err)
   114  	}
   115  	return n, err
   116  }