github.com/sagernet/quic-go@v0.43.1-beta.1/send_stream.go (about)

     1  package quic
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"sync"
     7  	"time"
     8  
     9  	"github.com/sagernet/quic-go/internal/ackhandler"
    10  	"github.com/sagernet/quic-go/internal/flowcontrol"
    11  	"github.com/sagernet/quic-go/internal/protocol"
    12  	"github.com/sagernet/quic-go/internal/qerr"
    13  	"github.com/sagernet/quic-go/internal/utils"
    14  	"github.com/sagernet/quic-go/internal/wire"
    15  )
    16  
// sendStreamI is the package-internal view of a send stream: the exported
// SendStream API plus the unexported hooks that sendStream implements below.
type sendStreamI interface {
	SendStream
	// handleStopSendingFrame processes a STOP_SENDING frame for this stream.
	handleStopSendingFrame(*wire.StopSendingFrame)
	// hasData reports whether a Write call still has unsent data pending.
	hasData() bool
	// popStreamFrame returns the next STREAM frame to send, limited to maxBytes.
	// ok reports whether a frame was returned, hasMore whether more data is waiting.
	popStreamFrame(maxBytes protocol.ByteCount, v protocol.Version) (frame ackhandler.StreamFrame, ok, hasMore bool)
	// closeForShutdown closes the stream abruptly with the given error.
	closeForShutdown(error)
	// updateSendWindow applies a new stream-level flow control limit.
	updateSendWindow(protocol.ByteCount)
}
    25  
// sendStream is the sending half of a stream. It accepts data from Write,
// hands it out as STREAM frames via popStreamFrame, and tracks
// acknowledgements, losses, cancellation and shutdown.
type sendStream struct {
	// mutex is locked by the methods below while they read or modify the mutable stream state.
	mutex sync.Mutex

	// numOutstandingFrames is incremented for every frame popped in popStreamFrame
	// and decremented again in OnAcked and OnLost.
	numOutstandingFrames int64
	// retransmissionQueue holds frames reported lost via OnLost until they are popped again.
	retransmissionQueue []*wire.StreamFrame

	// ctx is returned by Context. ctxCancel is called from Close, cancelWriteImpl and closeForShutdown.
	ctx       context.Context
	ctxCancel context.CancelCauseFunc

	streamID protocol.StreamID
	sender   streamSender

	// writeOffset is advanced by the data length of every newly popped STREAM frame.
	writeOffset protocol.ByteCount

	cancelWriteErr      error // set by cancelWriteImpl (local CancelWrite or remote STOP_SENDING)
	closeForShutdownErr error // set by closeForShutdown

	finishedWriting bool // set once Close() is called
	finSent         bool // set when a STREAM_FRAME with FIN bit has been sent
	// Set when the application knows about the cancellation.
	// This can happen because the application called CancelWrite,
	// or because Write returned the error (for remote cancellations).
	cancellationFlagged bool
	completed           bool // set when this stream has been reported to the streamSender as completed

	dataForWriting []byte // during a Write() call, this slice is the part of p that still needs to be sent out
	// nextFrame buffers data copied out of small writes until it is popped (see write and popNewStreamFrame).
	nextFrame *wire.StreamFrame

	writeChan chan struct{} // signaled by signalWrite to wake up a blocked Write
	writeOnce chan struct{} // held for the duration of Write, serializing concurrent calls
	deadline  time.Time     // write deadline set by SetWriteDeadline; the zero value means no deadline

	flowController flowcontrol.StreamFlowController
}
    60  
    61  var (
    62  	_ SendStream  = &sendStream{}
    63  	_ sendStreamI = &sendStream{}
    64  )
    65  
    66  func newSendStream(
    67  	ctx context.Context,
    68  	streamID protocol.StreamID,
    69  	sender streamSender,
    70  	flowController flowcontrol.StreamFlowController,
    71  ) *sendStream {
    72  	s := &sendStream{
    73  		streamID:       streamID,
    74  		sender:         sender,
    75  		flowController: flowController,
    76  		writeChan:      make(chan struct{}, 1),
    77  		writeOnce:      make(chan struct{}, 1), // cap: 1, to protect against concurrent use of Write
    78  	}
    79  	s.ctx, s.ctxCancel = context.WithCancelCause(ctx)
    80  	return s
    81  }
    82  
// StreamID returns the ID of this stream.
func (s *sendStream) StreamID() protocol.StreamID {
	return s.streamID // same for receiveStream and sendStream
}
    86  
    87  func (s *sendStream) Write(p []byte) (int, error) {
    88  	// Concurrent use of Write is not permitted (and doesn't make any sense),
    89  	// but sometimes people do it anyway.
    90  	// Make sure that we only execute one call at any given time to avoid hard to debug failures.
    91  	s.writeOnce <- struct{}{}
    92  	defer func() { <-s.writeOnce }()
    93  
    94  	isNewlyCompleted, n, err := s.write(p)
    95  	if isNewlyCompleted {
    96  		s.sender.onStreamCompleted(s.streamID)
    97  	}
    98  	return n, err
    99  }
   100  
// write implements Write. It holds the mutex except while notifying the sender
// and while waiting on writeChan / the deadline timer.
//
// It returns once all of p has either been popped by popStreamFrame or copied
// into s.nextFrame, or once the stream was canceled, shut down, or the write
// deadline expired. The results are: whether the stream became newly completed
// (the caller must then report it to the sender), the number of bytes of p
// that were consumed, and the error, if any.
func (s *sendStream) write(p []byte) (bool /* is newly completed */, int, error) {
	s.mutex.Lock()
	defer s.mutex.Unlock()

	if s.finishedWriting {
		return false, 0, fmt.Errorf("write on closed stream %d", s.streamID)
	}
	if s.cancelWriteErr != nil {
		// The application learns about the cancellation through this error.
		s.cancellationFlagged = true
		return s.isNewlyCompleted(), 0, s.cancelWriteErr
	}
	if s.closeForShutdownErr != nil {
		return false, 0, s.closeForShutdownErr
	}
	if !s.deadline.IsZero() && !time.Now().Before(s.deadline) {
		return false, 0, errDeadline
	}
	if len(p) == 0 {
		return false, 0, nil
	}

	s.dataForWriting = p

	var (
		deadlineTimer  *utils.Timer
		bytesWritten   int
		notifiedSender bool
	)
	for {
		var copied bool
		var deadline time.Time
		// As soon as dataForWriting becomes smaller than a certain size x, we copy all the data to a STREAM frame (s.nextFrame),
		// which can then be popped the next time we assemble a packet.
		// This allows us to return Write() when all data but x bytes have been sent out.
		// When the user now calls Close(), this is much more likely to happen before we popped that last STREAM frame,
		// allowing us to set the FIN bit on that frame (instead of sending an empty STREAM frame with FIN).
		if s.canBufferStreamFrame() && len(s.dataForWriting) > 0 {
			if s.nextFrame == nil {
				f := wire.GetStreamFrame()
				f.Offset = s.writeOffset
				f.StreamID = s.streamID
				f.DataLenPresent = true
				f.Data = f.Data[:len(s.dataForWriting)]
				copy(f.Data, s.dataForWriting)
				s.nextFrame = f
			} else {
				// Append to the data that is already buffered in nextFrame.
				l := len(s.nextFrame.Data)
				s.nextFrame.Data = s.nextFrame.Data[:l+len(s.dataForWriting)]
				copy(s.nextFrame.Data[l:], s.dataForWriting)
			}
			s.dataForWriting = nil
			bytesWritten = len(p)
			copied = true
		} else {
			// Whatever is no longer in dataForWriting has been popped by popStreamFrame.
			bytesWritten = len(p) - len(s.dataForWriting)
			deadline = s.deadline
			if !deadline.IsZero() {
				if !time.Now().Before(deadline) {
					s.dataForWriting = nil
					return false, bytesWritten, errDeadline
				}
				if deadlineTimer == nil {
					deadlineTimer = utils.NewTimer()
					defer deadlineTimer.Stop()
				}
				deadlineTimer.Reset(deadline)
			}
			// Stop waiting once all data was popped, or the stream was canceled or shut down.
			if s.dataForWriting == nil || s.cancelWriteErr != nil || s.closeForShutdownErr != nil {
				break
			}
		}

		s.mutex.Unlock()
		if !notifiedSender {
			s.sender.onHasStreamData(s.streamID) // must be called without holding the mutex
			notifiedSender = true
		}
		if copied {
			// Everything is buffered in nextFrame, no need to wait for it to be sent.
			s.mutex.Lock()
			break
		}
		// Block until signalWrite is called or, if a deadline is set, until the timer fires.
		if deadline.IsZero() {
			<-s.writeChan
		} else {
			select {
			case <-s.writeChan:
			case <-deadlineTimer.Chan():
				// presumably records that the timer channel was drained — confirm against utils.Timer
				deadlineTimer.SetRead()
			}
		}
		s.mutex.Lock()
	}

	if bytesWritten == len(p) {
		return false, bytesWritten, nil
	}
	if s.closeForShutdownErr != nil {
		return false, bytesWritten, s.closeForShutdownErr
	} else if s.cancelWriteErr != nil {
		s.cancellationFlagged = true
		return s.isNewlyCompleted(), bytesWritten, s.cancelWriteErr
	}
	return false, bytesWritten, nil
}
   205  
   206  func (s *sendStream) canBufferStreamFrame() bool {
   207  	var l protocol.ByteCount
   208  	if s.nextFrame != nil {
   209  		l = s.nextFrame.DataLen()
   210  	}
   211  	return l+protocol.ByteCount(len(s.dataForWriting)) <= protocol.MaxPacketBufferSize
   212  }
   213  
   214  // popStreamFrame returns the next STREAM frame that is supposed to be sent on this stream
   215  // maxBytes is the maximum length this frame (including frame header) will have.
   216  func (s *sendStream) popStreamFrame(maxBytes protocol.ByteCount, v protocol.Version) (af ackhandler.StreamFrame, ok, hasMore bool) {
   217  	s.mutex.Lock()
   218  	f, hasMoreData := s.popNewOrRetransmittedStreamFrame(maxBytes, v)
   219  	if f != nil {
   220  		s.numOutstandingFrames++
   221  	}
   222  	s.mutex.Unlock()
   223  
   224  	if f == nil {
   225  		return ackhandler.StreamFrame{}, false, hasMoreData
   226  	}
   227  	return ackhandler.StreamFrame{
   228  		Frame:   f,
   229  		Handler: (*sendStreamAckHandler)(s),
   230  	}, true, hasMoreData
   231  }
   232  
   233  func (s *sendStream) popNewOrRetransmittedStreamFrame(maxBytes protocol.ByteCount, v protocol.Version) (*wire.StreamFrame, bool /* has more data to send */) {
   234  	if s.cancelWriteErr != nil || s.closeForShutdownErr != nil {
   235  		return nil, false
   236  	}
   237  
   238  	if len(s.retransmissionQueue) > 0 {
   239  		f, hasMoreRetransmissions := s.maybeGetRetransmission(maxBytes, v)
   240  		if f != nil || hasMoreRetransmissions {
   241  			if f == nil {
   242  				return nil, true
   243  			}
   244  			// We always claim that we have more data to send.
   245  			// This might be incorrect, in which case there'll be a spurious call to popStreamFrame in the future.
   246  			return f, true
   247  		}
   248  	}
   249  
   250  	if len(s.dataForWriting) == 0 && s.nextFrame == nil {
   251  		if s.finishedWriting && !s.finSent {
   252  			s.finSent = true
   253  			return &wire.StreamFrame{
   254  				StreamID:       s.streamID,
   255  				Offset:         s.writeOffset,
   256  				DataLenPresent: true,
   257  				Fin:            true,
   258  			}, false
   259  		}
   260  		return nil, false
   261  	}
   262  
   263  	sendWindow := s.flowController.SendWindowSize()
   264  	if sendWindow == 0 {
   265  		if isBlocked, offset := s.flowController.IsNewlyBlocked(); isBlocked {
   266  			s.sender.queueControlFrame(&wire.StreamDataBlockedFrame{
   267  				StreamID:          s.streamID,
   268  				MaximumStreamData: offset,
   269  			})
   270  			return nil, false
   271  		}
   272  		return nil, true
   273  	}
   274  
   275  	f, hasMoreData := s.popNewStreamFrame(maxBytes, sendWindow, v)
   276  	if dataLen := f.DataLen(); dataLen > 0 {
   277  		s.writeOffset += f.DataLen()
   278  		s.flowController.AddBytesSent(f.DataLen())
   279  	}
   280  	f.Fin = s.finishedWriting && s.dataForWriting == nil && s.nextFrame == nil && !s.finSent
   281  	if f.Fin {
   282  		s.finSent = true
   283  	}
   284  	return f, hasMoreData
   285  }
   286  
   287  func (s *sendStream) popNewStreamFrame(maxBytes, sendWindow protocol.ByteCount, v protocol.Version) (*wire.StreamFrame, bool) {
   288  	if s.nextFrame != nil {
   289  		nextFrame := s.nextFrame
   290  		s.nextFrame = nil
   291  
   292  		maxDataLen := utils.Min(sendWindow, nextFrame.MaxDataLen(maxBytes, v))
   293  		if nextFrame.DataLen() > maxDataLen {
   294  			s.nextFrame = wire.GetStreamFrame()
   295  			s.nextFrame.StreamID = s.streamID
   296  			s.nextFrame.Offset = s.writeOffset + maxDataLen
   297  			s.nextFrame.Data = s.nextFrame.Data[:nextFrame.DataLen()-maxDataLen]
   298  			s.nextFrame.DataLenPresent = true
   299  			copy(s.nextFrame.Data, nextFrame.Data[maxDataLen:])
   300  			nextFrame.Data = nextFrame.Data[:maxDataLen]
   301  		} else {
   302  			s.signalWrite()
   303  		}
   304  		return nextFrame, s.nextFrame != nil || s.dataForWriting != nil
   305  	}
   306  
   307  	f := wire.GetStreamFrame()
   308  	f.Fin = false
   309  	f.StreamID = s.streamID
   310  	f.Offset = s.writeOffset
   311  	f.DataLenPresent = true
   312  	f.Data = f.Data[:0]
   313  
   314  	hasMoreData := s.popNewStreamFrameWithoutBuffer(f, maxBytes, sendWindow, v)
   315  	if len(f.Data) == 0 && !f.Fin {
   316  		f.PutBack()
   317  		return nil, hasMoreData
   318  	}
   319  	return f, hasMoreData
   320  }
   321  
   322  func (s *sendStream) popNewStreamFrameWithoutBuffer(f *wire.StreamFrame, maxBytes, sendWindow protocol.ByteCount, v protocol.Version) bool {
   323  	maxDataLen := f.MaxDataLen(maxBytes, v)
   324  	if maxDataLen == 0 { // a STREAM frame must have at least one byte of data
   325  		return s.dataForWriting != nil || s.nextFrame != nil || s.finishedWriting
   326  	}
   327  	s.getDataForWriting(f, utils.Min(maxDataLen, sendWindow))
   328  
   329  	return s.dataForWriting != nil || s.nextFrame != nil || s.finishedWriting
   330  }
   331  
   332  func (s *sendStream) maybeGetRetransmission(maxBytes protocol.ByteCount, v protocol.Version) (*wire.StreamFrame, bool /* has more retransmissions */) {
   333  	f := s.retransmissionQueue[0]
   334  	newFrame, needsSplit := f.MaybeSplitOffFrame(maxBytes, v)
   335  	if needsSplit {
   336  		return newFrame, true
   337  	}
   338  	s.retransmissionQueue = s.retransmissionQueue[1:]
   339  	return f, len(s.retransmissionQueue) > 0
   340  }
   341  
   342  func (s *sendStream) hasData() bool {
   343  	s.mutex.Lock()
   344  	hasData := len(s.dataForWriting) > 0
   345  	s.mutex.Unlock()
   346  	return hasData
   347  }
   348  
   349  func (s *sendStream) getDataForWriting(f *wire.StreamFrame, maxBytes protocol.ByteCount) {
   350  	if protocol.ByteCount(len(s.dataForWriting)) <= maxBytes {
   351  		f.Data = f.Data[:len(s.dataForWriting)]
   352  		copy(f.Data, s.dataForWriting)
   353  		s.dataForWriting = nil
   354  		s.signalWrite()
   355  		return
   356  	}
   357  	f.Data = f.Data[:maxBytes]
   358  	copy(f.Data, s.dataForWriting)
   359  	s.dataForWriting = s.dataForWriting[maxBytes:]
   360  	if s.canBufferStreamFrame() {
   361  		s.signalWrite()
   362  	}
   363  }
   364  
   365  func (s *sendStream) isNewlyCompleted() bool {
   366  	if s.completed {
   367  		return false
   368  	}
   369  	// We need to keep the stream around until all frames have been sent and acknowledged.
   370  	if s.numOutstandingFrames > 0 || len(s.retransmissionQueue) > 0 {
   371  		return false
   372  	}
   373  	// The stream is completed if we sent the FIN.
   374  	if s.finSent {
   375  		s.completed = true
   376  		return true
   377  	}
   378  	// The stream is also completed if:
   379  	// 1. the application called CancelWrite, or
   380  	// 2. we received a STOP_SENDING, and
   381  	// 		* the application consumed the error via Write, or
   382  	//		* the application called CLsoe
   383  	if s.cancelWriteErr != nil && (s.cancellationFlagged || s.finishedWriting) {
   384  		s.completed = true
   385  		return true
   386  	}
   387  	return false
   388  }
   389  
   390  func (s *sendStream) Close() error {
   391  	s.mutex.Lock()
   392  	if s.closeForShutdownErr != nil {
   393  		s.mutex.Unlock()
   394  		return nil
   395  	}
   396  	s.finishedWriting = true
   397  	cancelWriteErr := s.cancelWriteErr
   398  	if cancelWriteErr != nil {
   399  		s.cancellationFlagged = true
   400  	}
   401  	completed := s.isNewlyCompleted()
   402  	s.mutex.Unlock()
   403  
   404  	if completed {
   405  		s.sender.onStreamCompleted(s.streamID)
   406  	}
   407  	if cancelWriteErr != nil {
   408  		return fmt.Errorf("close called for canceled stream %d", s.streamID)
   409  	}
   410  	s.sender.onHasStreamData(s.streamID) // need to send the FIN, must be called without holding the mutex
   411  
   412  	s.ctxCancel(nil)
   413  	return nil
   414  }
   415  
// CancelWrite aborts sending on this stream with the given error code.
// It is the locally initiated counterpart of handleStopSendingFrame.
func (s *sendStream) CancelWrite(errorCode StreamErrorCode) {
	s.cancelWriteImpl(errorCode, false)
}
   419  
   420  func (s *sendStream) cancelWriteImpl(errorCode qerr.StreamErrorCode, remote bool) {
   421  	s.mutex.Lock()
   422  	if !remote {
   423  		s.cancellationFlagged = true
   424  	}
   425  	if s.cancelWriteErr != nil {
   426  		s.mutex.Unlock()
   427  		return
   428  	}
   429  	s.cancelWriteErr = &StreamError{StreamID: s.streamID, ErrorCode: errorCode, Remote: remote}
   430  	s.ctxCancel(s.cancelWriteErr)
   431  	s.numOutstandingFrames = 0
   432  	s.retransmissionQueue = nil
   433  	newlyCompleted := s.isNewlyCompleted()
   434  	s.mutex.Unlock()
   435  
   436  	s.signalWrite()
   437  	s.sender.queueControlFrame(&wire.ResetStreamFrame{
   438  		StreamID:  s.streamID,
   439  		FinalSize: s.writeOffset,
   440  		ErrorCode: errorCode,
   441  	})
   442  	if newlyCompleted {
   443  		s.sender.onStreamCompleted(s.streamID)
   444  	}
   445  }
   446  
   447  func (s *sendStream) updateSendWindow(limit protocol.ByteCount) {
   448  	updated := s.flowController.UpdateSendWindow(limit)
   449  	if !updated { // duplicate or reordered MAX_STREAM_DATA frame
   450  		return
   451  	}
   452  	s.mutex.Lock()
   453  	hasStreamData := s.dataForWriting != nil || s.nextFrame != nil
   454  	s.mutex.Unlock()
   455  	if hasStreamData {
   456  		s.sender.onHasStreamData(s.streamID)
   457  	}
   458  }
   459  
// handleStopSendingFrame handles a STOP_SENDING frame by canceling the send
// side with the frame's error code, marked as a remote cancellation.
func (s *sendStream) handleStopSendingFrame(frame *wire.StopSendingFrame) {
	s.cancelWriteImpl(frame.ErrorCode, true)
}
   463  
// Context returns the stream's context. It is canceled by Close,
// by a (local or remote) write cancellation, and by closeForShutdown.
func (s *sendStream) Context() context.Context {
	return s.ctx
}
   467  
// SetWriteDeadline sets the deadline for Write calls. A zero value for t
// means that there is no deadline. It always returns nil.
func (s *sendStream) SetWriteDeadline(t time.Time) error {
	s.mutex.Lock()
	s.deadline = t
	s.mutex.Unlock()
	// Wake up a blocked Write so that it picks up the new deadline.
	s.signalWrite()
	return nil
}
   475  
// closeForShutdown closes a stream abruptly.
// It makes Write unblock (and return the error) immediately.
// The peer will NOT be informed about this: the stream is closed without sending a FIN or RST.
func (s *sendStream) closeForShutdown(err error) {
	s.mutex.Lock()
	s.ctxCancel(err)
	s.closeForShutdownErr = err
	s.mutex.Unlock()
	// Wake up a blocked Write so that it observes closeForShutdownErr.
	s.signalWrite()
}
   486  
// signalWrite performs a non-blocking send on the writeChan.
// If a signal is already pending in the channel, this call is a no-op.
func (s *sendStream) signalWrite() {
	select {
	case s.writeChan <- struct{}{}:
	default:
	}
}
   494  
// sendStreamAckHandler is sendStream under a different name, so that the
// OnAcked and OnLost callbacks can be attached to it as methods.
// popStreamFrame converts the stream pointer to this type for every frame it returns.
type sendStreamAckHandler sendStream

// Compile-time check that the handler implements ackhandler.FrameHandler.
var _ ackhandler.FrameHandler = &sendStreamAckHandler{}
   498  
   499  func (s *sendStreamAckHandler) OnAcked(f wire.Frame) {
   500  	sf := f.(*wire.StreamFrame)
   501  	sf.PutBack()
   502  	s.mutex.Lock()
   503  	if s.cancelWriteErr != nil {
   504  		s.mutex.Unlock()
   505  		return
   506  	}
   507  	s.numOutstandingFrames--
   508  	if s.numOutstandingFrames < 0 {
   509  		panic("numOutStandingFrames negative")
   510  	}
   511  	newlyCompleted := (*sendStream)(s).isNewlyCompleted()
   512  	s.mutex.Unlock()
   513  
   514  	if newlyCompleted {
   515  		s.sender.onStreamCompleted(s.streamID)
   516  	}
   517  }
   518  
   519  func (s *sendStreamAckHandler) OnLost(f wire.Frame) {
   520  	sf := f.(*wire.StreamFrame)
   521  	s.mutex.Lock()
   522  	if s.cancelWriteErr != nil {
   523  		s.mutex.Unlock()
   524  		return
   525  	}
   526  	sf.DataLenPresent = true
   527  	s.retransmissionQueue = append(s.retransmissionQueue, sf)
   528  	s.numOutstandingFrames--
   529  	if s.numOutstandingFrames < 0 {
   530  		panic("numOutStandingFrames negative")
   531  	}
   532  	s.mutex.Unlock()
   533  
   534  	s.sender.onHasStreamData(s.streamID)
   535  }