github.com/sagernet/quic-go@v0.43.1-beta.1/ech/receive_stream.go (about)

     1  package quic
     2  
     3  import (
     4  	"fmt"
     5  	"io"
     6  	"sync"
     7  	"time"
     8  
     9  	"github.com/sagernet/quic-go/internal/flowcontrol"
    10  	"github.com/sagernet/quic-go/internal/protocol"
    11  	"github.com/sagernet/quic-go/internal/qerr"
    12  	"github.com/sagernet/quic-go/internal/utils"
    13  	"github.com/sagernet/quic-go/internal/wire"
    14  )
    15  
    16  type receiveStreamI interface {
    17  	ReceiveStream
    18  
    19  	handleStreamFrame(*wire.StreamFrame) error
    20  	handleResetStreamFrame(*wire.ResetStreamFrame) error
    21  	closeForShutdown(error)
    22  	getWindowUpdate() protocol.ByteCount
    23  }
    24  
    25  type receiveStream struct {
    26  	mutex sync.Mutex
    27  
    28  	streamID protocol.StreamID
    29  
    30  	sender streamSender
    31  
    32  	frameQueue  *frameSorter
    33  	finalOffset protocol.ByteCount
    34  
    35  	currentFrame       []byte
    36  	currentFrameDone   func()
    37  	readPosInFrame     int
    38  	currentFrameIsLast bool // is the currentFrame the last frame on this stream
    39  
    40  	// Set once we read the io.EOF or the cancellation error.
    41  	// Note that for local cancellations, this doesn't necessarily mean that we know the final offset yet.
    42  	errorRead           bool
    43  	completed           bool // set once we've called streamSender.onStreamCompleted
    44  	cancelledRemotely   bool
    45  	cancelledLocally    bool
    46  	cancelErr           *StreamError
    47  	closeForShutdownErr error
    48  
    49  	readChan chan struct{}
    50  	readOnce chan struct{} // cap: 1, to protect against concurrent use of Read
    51  	deadline time.Time
    52  
    53  	flowController flowcontrol.StreamFlowController
    54  }
    55  
    56  var (
    57  	_ ReceiveStream  = &receiveStream{}
    58  	_ receiveStreamI = &receiveStream{}
    59  )
    60  
    61  func newReceiveStream(
    62  	streamID protocol.StreamID,
    63  	sender streamSender,
    64  	flowController flowcontrol.StreamFlowController,
    65  ) *receiveStream {
    66  	return &receiveStream{
    67  		streamID:       streamID,
    68  		sender:         sender,
    69  		flowController: flowController,
    70  		frameQueue:     newFrameSorter(),
    71  		readChan:       make(chan struct{}, 1),
    72  		readOnce:       make(chan struct{}, 1),
    73  		finalOffset:    protocol.MaxByteCount,
    74  	}
    75  }
    76  
    77  func (s *receiveStream) StreamID() protocol.StreamID {
    78  	return s.streamID
    79  }
    80  
    81  // Read implements io.Reader. It is not thread safe!
    82  func (s *receiveStream) Read(p []byte) (int, error) {
    83  	// Concurrent use of Read is not permitted (and doesn't make any sense),
    84  	// but sometimes people do it anyway.
    85  	// Make sure that we only execute one call at any given time to avoid hard to debug failures.
    86  	s.readOnce <- struct{}{}
    87  	defer func() { <-s.readOnce }()
    88  
    89  	s.mutex.Lock()
    90  	n, err := s.readImpl(p)
    91  	completed := s.isNewlyCompleted()
    92  	s.mutex.Unlock()
    93  
    94  	if completed {
    95  		s.sender.onStreamCompleted(s.streamID)
    96  	}
    97  	return n, err
    98  }
    99  
   100  func (s *receiveStream) isNewlyCompleted() bool {
   101  	if s.completed {
   102  		return false
   103  	}
   104  	// We need to know the final offset (either via FIN or RESET_STREAM) for flow control accounting.
   105  	if s.finalOffset == protocol.MaxByteCount {
   106  		return false
   107  	}
   108  	// We're done with the stream if it was cancelled locally...
   109  	if s.cancelledLocally {
   110  		s.completed = true
   111  		return true
   112  	}
   113  	// ... or if the error (either io.EOF or the reset error) was read
   114  	if s.errorRead {
   115  		s.completed = true
   116  		return true
   117  	}
   118  	return false
   119  }
   120  
   121  func (s *receiveStream) readImpl(p []byte) (int, error) {
   122  	if s.currentFrameIsLast && s.currentFrame == nil {
   123  		s.errorRead = true
   124  		return 0, io.EOF
   125  	}
   126  	if s.cancelledRemotely || s.cancelledLocally {
   127  		s.errorRead = true
   128  		return 0, s.cancelErr
   129  	}
   130  	if s.closeForShutdownErr != nil {
   131  		return 0, s.closeForShutdownErr
   132  	}
   133  
   134  	var bytesRead int
   135  	var deadlineTimer *utils.Timer
   136  	for bytesRead < len(p) {
   137  		if s.currentFrame == nil || s.readPosInFrame >= len(s.currentFrame) {
   138  			s.dequeueNextFrame()
   139  		}
   140  		if s.currentFrame == nil && bytesRead > 0 {
   141  			return bytesRead, s.closeForShutdownErr
   142  		}
   143  
   144  		for {
   145  			// Stop waiting on errors
   146  			if s.closeForShutdownErr != nil {
   147  				return bytesRead, s.closeForShutdownErr
   148  			}
   149  			if s.cancelledRemotely || s.cancelledLocally {
   150  				s.errorRead = true
   151  				return 0, s.cancelErr
   152  			}
   153  
   154  			deadline := s.deadline
   155  			if !deadline.IsZero() {
   156  				if !time.Now().Before(deadline) {
   157  					return bytesRead, errDeadline
   158  				}
   159  				if deadlineTimer == nil {
   160  					deadlineTimer = utils.NewTimer()
   161  					defer deadlineTimer.Stop()
   162  				}
   163  				deadlineTimer.Reset(deadline)
   164  			}
   165  
   166  			if s.currentFrame != nil || s.currentFrameIsLast {
   167  				break
   168  			}
   169  
   170  			s.mutex.Unlock()
   171  			if deadline.IsZero() {
   172  				<-s.readChan
   173  			} else {
   174  				select {
   175  				case <-s.readChan:
   176  				case <-deadlineTimer.Chan():
   177  					deadlineTimer.SetRead()
   178  				}
   179  			}
   180  			s.mutex.Lock()
   181  			if s.currentFrame == nil {
   182  				s.dequeueNextFrame()
   183  			}
   184  		}
   185  
   186  		if bytesRead > len(p) {
   187  			return bytesRead, fmt.Errorf("BUG: bytesRead (%d) > len(p) (%d) in stream.Read", bytesRead, len(p))
   188  		}
   189  		if s.readPosInFrame > len(s.currentFrame) {
   190  			return bytesRead, fmt.Errorf("BUG: readPosInFrame (%d) > frame.DataLen (%d) in stream.Read", s.readPosInFrame, len(s.currentFrame))
   191  		}
   192  
   193  		m := copy(p[bytesRead:], s.currentFrame[s.readPosInFrame:])
   194  		s.readPosInFrame += m
   195  		bytesRead += m
   196  
   197  		// when a RESET_STREAM was received, the flow controller was already
   198  		// informed about the final byteOffset for this stream
   199  		if !s.cancelledRemotely {
   200  			s.flowController.AddBytesRead(protocol.ByteCount(m))
   201  		}
   202  
   203  		if s.readPosInFrame >= len(s.currentFrame) && s.currentFrameIsLast {
   204  			s.currentFrame = nil
   205  			if s.currentFrameDone != nil {
   206  				s.currentFrameDone()
   207  			}
   208  			s.errorRead = true
   209  			return bytesRead, io.EOF
   210  		}
   211  	}
   212  	return bytesRead, nil
   213  }
   214  
   215  func (s *receiveStream) dequeueNextFrame() {
   216  	var offset protocol.ByteCount
   217  	// We're done with the last frame. Release the buffer.
   218  	if s.currentFrameDone != nil {
   219  		s.currentFrameDone()
   220  	}
   221  	offset, s.currentFrame, s.currentFrameDone = s.frameQueue.Pop()
   222  	s.currentFrameIsLast = offset+protocol.ByteCount(len(s.currentFrame)) >= s.finalOffset
   223  	s.readPosInFrame = 0
   224  }
   225  
   226  func (s *receiveStream) CancelRead(errorCode StreamErrorCode) {
   227  	s.mutex.Lock()
   228  	s.cancelReadImpl(errorCode)
   229  	completed := s.isNewlyCompleted()
   230  	s.mutex.Unlock()
   231  
   232  	if completed {
   233  		s.flowController.Abandon()
   234  		s.sender.onStreamCompleted(s.streamID)
   235  	}
   236  }
   237  
   238  func (s *receiveStream) cancelReadImpl(errorCode qerr.StreamErrorCode) {
   239  	if s.cancelledLocally { // duplicate call to CancelRead
   240  		return
   241  	}
   242  	s.cancelledLocally = true
   243  	if s.errorRead || s.cancelledRemotely {
   244  		return
   245  	}
   246  	s.cancelErr = &StreamError{StreamID: s.streamID, ErrorCode: errorCode, Remote: false}
   247  	s.signalRead()
   248  	s.sender.queueControlFrame(&wire.StopSendingFrame{
   249  		StreamID:  s.streamID,
   250  		ErrorCode: errorCode,
   251  	})
   252  }
   253  
   254  func (s *receiveStream) handleStreamFrame(frame *wire.StreamFrame) error {
   255  	s.mutex.Lock()
   256  	err := s.handleStreamFrameImpl(frame)
   257  	completed := s.isNewlyCompleted()
   258  	s.mutex.Unlock()
   259  
   260  	if completed {
   261  		s.flowController.Abandon()
   262  		s.sender.onStreamCompleted(s.streamID)
   263  	}
   264  	return err
   265  }
   266  
   267  func (s *receiveStream) handleStreamFrameImpl(frame *wire.StreamFrame) error {
   268  	maxOffset := frame.Offset + frame.DataLen()
   269  	if err := s.flowController.UpdateHighestReceived(maxOffset, frame.Fin); err != nil {
   270  		return err
   271  	}
   272  	if frame.Fin {
   273  		s.finalOffset = maxOffset
   274  	}
   275  	if s.cancelledLocally {
   276  		return nil
   277  	}
   278  	if err := s.frameQueue.Push(frame.Data, frame.Offset, frame.PutBack); err != nil {
   279  		return err
   280  	}
   281  	s.signalRead()
   282  	return nil
   283  }
   284  
   285  func (s *receiveStream) handleResetStreamFrame(frame *wire.ResetStreamFrame) error {
   286  	s.mutex.Lock()
   287  	err := s.handleResetStreamFrameImpl(frame)
   288  	completed := s.isNewlyCompleted()
   289  	s.mutex.Unlock()
   290  
   291  	if completed {
   292  		s.sender.onStreamCompleted(s.streamID)
   293  	}
   294  	return err
   295  }
   296  
   297  func (s *receiveStream) handleResetStreamFrameImpl(frame *wire.ResetStreamFrame) error {
   298  	if s.closeForShutdownErr != nil {
   299  		return nil
   300  	}
   301  	if err := s.flowController.UpdateHighestReceived(frame.FinalSize, true); err != nil {
   302  		return err
   303  	}
   304  	s.finalOffset = frame.FinalSize
   305  
   306  	// ignore duplicate RESET_STREAM frames for this stream (after checking their final offset)
   307  	if s.cancelledRemotely {
   308  		return nil
   309  	}
   310  	s.flowController.Abandon()
   311  	// don't save the error if the RESET_STREAM frames was received after CancelRead was called
   312  	if s.cancelledLocally {
   313  		return nil
   314  	}
   315  	s.cancelledRemotely = true
   316  	s.cancelErr = &StreamError{StreamID: s.streamID, ErrorCode: frame.ErrorCode, Remote: true}
   317  	s.signalRead()
   318  	return nil
   319  }
   320  
   321  func (s *receiveStream) SetReadDeadline(t time.Time) error {
   322  	s.mutex.Lock()
   323  	s.deadline = t
   324  	s.mutex.Unlock()
   325  	s.signalRead()
   326  	return nil
   327  }
   328  
   329  // CloseForShutdown closes a stream abruptly.
   330  // It makes Read unblock (and return the error) immediately.
   331  // The peer will NOT be informed about this: the stream is closed without sending a FIN or RESET.
   332  func (s *receiveStream) closeForShutdown(err error) {
   333  	s.mutex.Lock()
   334  	s.closeForShutdownErr = err
   335  	s.mutex.Unlock()
   336  	s.signalRead()
   337  }
   338  
   339  func (s *receiveStream) getWindowUpdate() protocol.ByteCount {
   340  	return s.flowController.GetWindowUpdate()
   341  }
   342  
   343  // signalRead performs a non-blocking send on the readChan
   344  func (s *receiveStream) signalRead() {
   345  	select {
   346  	case s.readChan <- struct{}{}:
   347  	default:
   348  	}
   349  }