github.com/metacubex/quic-go@v0.44.1-0.20240520163451-20b689a59136/http3/state_tracking_stream.go (about)

     1  package http3
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  	"os"
     7  	"sync"
     8  
     9  	"github.com/metacubex/quic-go"
    10  	"github.com/metacubex/quic-go/internal/utils"
    11  )
    12  
    13  var _ quic.Stream = &stateTrackingStream{}
    14  
    15  // stateTrackingStream is an implementation of quic.Stream that delegates
    16  // to an underlying stream
    17  // it takes care of proxying send and receive errors onto an implementation of
    18  // the errorSetter interface (intended to be occupied by a datagrammer)
    19  // it is also responsible for clearing the stream based on its ID from its
    20  // parent connection, this is done through the streamClearer interface when
    21  // both the send and receive sides are closed
    22  type stateTrackingStream struct {
    23  	quic.Stream
    24  
    25  	mx      sync.Mutex
    26  	sendErr error
    27  	recvErr error
    28  
    29  	clearer streamClearer
    30  	setter  errorSetter
    31  }
    32  
    33  type streamClearer interface {
    34  	clearStream(quic.StreamID)
    35  }
    36  
    37  type errorSetter interface {
    38  	SetSendError(error)
    39  	SetReceiveError(error)
    40  }
    41  
    42  func newStateTrackingStream(s quic.Stream, clearer streamClearer, setter errorSetter) *stateTrackingStream {
    43  	t := &stateTrackingStream{
    44  		Stream:  s,
    45  		clearer: clearer,
    46  		setter:  setter,
    47  	}
    48  
    49  	utils.AfterFunc(s.Context(), func() {
    50  		t.closeSend(context.Cause(s.Context()))
    51  	})
    52  
    53  	return t
    54  }
    55  
    56  func (s *stateTrackingStream) closeSend(e error) {
    57  	s.mx.Lock()
    58  	defer s.mx.Unlock()
    59  
    60  	// clear the stream the first time both the send
    61  	// and receive are finished
    62  	if s.sendErr == nil {
    63  		if s.recvErr != nil {
    64  			s.clearer.clearStream(s.StreamID())
    65  		}
    66  
    67  		s.setter.SetSendError(e)
    68  		s.sendErr = e
    69  	}
    70  }
    71  
    72  func (s *stateTrackingStream) closeReceive(e error) {
    73  	s.mx.Lock()
    74  	defer s.mx.Unlock()
    75  
    76  	// clear the stream the first time both the send
    77  	// and receive are finished
    78  	if s.recvErr == nil {
    79  		if s.sendErr != nil {
    80  			s.clearer.clearStream(s.StreamID())
    81  		}
    82  
    83  		s.setter.SetReceiveError(e)
    84  		s.recvErr = e
    85  	}
    86  }
    87  
    88  func (s *stateTrackingStream) Close() error {
    89  	s.closeSend(errors.New("write on closed stream"))
    90  	return s.Stream.Close()
    91  }
    92  
    93  func (s *stateTrackingStream) CancelWrite(e quic.StreamErrorCode) {
    94  	s.closeSend(&quic.StreamError{StreamID: s.Stream.StreamID(), ErrorCode: e})
    95  	s.Stream.CancelWrite(e)
    96  }
    97  
    98  func (s *stateTrackingStream) Write(b []byte) (int, error) {
    99  	n, err := s.Stream.Write(b)
   100  	if err != nil && !errors.Is(err, os.ErrDeadlineExceeded) {
   101  		s.closeSend(err)
   102  	}
   103  	return n, err
   104  }
   105  
   106  func (s *stateTrackingStream) CancelRead(e quic.StreamErrorCode) {
   107  	s.closeReceive(&quic.StreamError{StreamID: s.Stream.StreamID(), ErrorCode: e})
   108  	s.Stream.CancelRead(e)
   109  }
   110  
   111  func (s *stateTrackingStream) Read(b []byte) (int, error) {
   112  	n, err := s.Stream.Read(b)
   113  	if err != nil && !errors.Is(err, os.ErrDeadlineExceeded) {
   114  		s.closeReceive(err)
   115  	}
   116  	return n, err
   117  }