github.com/daeuniverse/quic-go@v0.0.0-20240413031024-943f218e0810/send_stream.go (about)

     1  package quic
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"sync"
     7  	"time"
     8  
     9  	"github.com/daeuniverse/quic-go/internal/ackhandler"
    10  	"github.com/daeuniverse/quic-go/internal/flowcontrol"
    11  	"github.com/daeuniverse/quic-go/internal/protocol"
    12  	"github.com/daeuniverse/quic-go/internal/qerr"
    13  	"github.com/daeuniverse/quic-go/internal/utils"
    14  	"github.com/daeuniverse/quic-go/internal/wire"
    15  )
    16  
    17  type sendStreamI interface {
    18  	SendStream
    19  	handleStopSendingFrame(*wire.StopSendingFrame)
    20  	hasData() bool
    21  	popStreamFrame(maxBytes protocol.ByteCount, v protocol.Version) (frame ackhandler.StreamFrame, ok, hasMore bool)
    22  	closeForShutdown(error)
    23  	updateSendWindow(protocol.ByteCount)
    24  }
    25  
    26  type sendStream struct {
    27  	mutex sync.Mutex
    28  
    29  	numOutstandingFrames int64
    30  	retransmissionQueue  []*wire.StreamFrame
    31  
    32  	ctx       context.Context
    33  	ctxCancel context.CancelCauseFunc
    34  
    35  	streamID protocol.StreamID
    36  	sender   streamSender
    37  
    38  	writeOffset protocol.ByteCount
    39  
    40  	cancelWriteErr      error
    41  	closeForShutdownErr error
    42  
    43  	finishedWriting bool // set once Close() is called
    44  	finSent         bool // set when a STREAM_FRAME with FIN bit has been sent
    45  	completed       bool // set when this stream has been reported to the streamSender as completed
    46  
    47  	dataForWriting []byte // during a Write() call, this slice is the part of p that still needs to be sent out
    48  	nextFrame      *wire.StreamFrame
    49  
    50  	writeChan chan struct{}
    51  	writeOnce chan struct{}
    52  	deadline  time.Time
    53  
    54  	flowController flowcontrol.StreamFlowController
    55  }
    56  
    57  var (
    58  	_ SendStream  = &sendStream{}
    59  	_ sendStreamI = &sendStream{}
    60  )
    61  
    62  func newSendStream(
    63  	streamID protocol.StreamID,
    64  	sender streamSender,
    65  	flowController flowcontrol.StreamFlowController,
    66  ) *sendStream {
    67  	s := &sendStream{
    68  		streamID:       streamID,
    69  		sender:         sender,
    70  		flowController: flowController,
    71  		writeChan:      make(chan struct{}, 1),
    72  		writeOnce:      make(chan struct{}, 1), // cap: 1, to protect against concurrent use of Write
    73  	}
    74  	s.ctx, s.ctxCancel = context.WithCancelCause(context.Background())
    75  	return s
    76  }
    77  
    78  func (s *sendStream) StreamID() protocol.StreamID {
    79  	return s.streamID // same for receiveStream and sendStream
    80  }
    81  
    82  func (s *sendStream) Write(p []byte) (int, error) {
    83  	// Concurrent use of Write is not permitted (and doesn't make any sense),
    84  	// but sometimes people do it anyway.
    85  	// Make sure that we only execute one call at any given time to avoid hard to debug failures.
    86  	s.writeOnce <- struct{}{}
    87  	defer func() { <-s.writeOnce }()
    88  
    89  	s.mutex.Lock()
    90  	defer s.mutex.Unlock()
    91  
    92  	if s.finishedWriting {
    93  		return 0, fmt.Errorf("write on closed stream %d", s.streamID)
    94  	}
    95  	if s.cancelWriteErr != nil {
    96  		return 0, s.cancelWriteErr
    97  	}
    98  	if s.closeForShutdownErr != nil {
    99  		return 0, s.closeForShutdownErr
   100  	}
   101  	if !s.deadline.IsZero() && !time.Now().Before(s.deadline) {
   102  		return 0, errDeadline
   103  	}
   104  	if len(p) == 0 {
   105  		return 0, nil
   106  	}
   107  
   108  	s.dataForWriting = p
   109  
   110  	var (
   111  		deadlineTimer  *utils.Timer
   112  		bytesWritten   int
   113  		notifiedSender bool
   114  	)
   115  	for {
   116  		var copied bool
   117  		var deadline time.Time
   118  		// As soon as dataForWriting becomes smaller than a certain size x, we copy all the data to a STREAM frame (s.nextFrame),
   119  		// which can then be popped the next time we assemble a packet.
   120  		// This allows us to return Write() when all data but x bytes have been sent out.
   121  		// When the user now calls Close(), this is much more likely to happen before we popped that last STREAM frame,
   122  		// allowing us to set the FIN bit on that frame (instead of sending an empty STREAM frame with FIN).
   123  		if s.canBufferStreamFrame() && len(s.dataForWriting) > 0 {
   124  			if s.nextFrame == nil {
   125  				f := wire.GetStreamFrame()
   126  				f.Offset = s.writeOffset
   127  				f.StreamID = s.streamID
   128  				f.DataLenPresent = true
   129  				f.Data = f.Data[:len(s.dataForWriting)]
   130  				copy(f.Data, s.dataForWriting)
   131  				s.nextFrame = f
   132  			} else {
   133  				l := len(s.nextFrame.Data)
   134  				s.nextFrame.Data = s.nextFrame.Data[:l+len(s.dataForWriting)]
   135  				copy(s.nextFrame.Data[l:], s.dataForWriting)
   136  			}
   137  			s.dataForWriting = nil
   138  			bytesWritten = len(p)
   139  			copied = true
   140  		} else {
   141  			bytesWritten = len(p) - len(s.dataForWriting)
   142  			deadline = s.deadline
   143  			if !deadline.IsZero() {
   144  				if !time.Now().Before(deadline) {
   145  					s.dataForWriting = nil
   146  					return bytesWritten, errDeadline
   147  				}
   148  				if deadlineTimer == nil {
   149  					deadlineTimer = utils.NewTimer()
   150  					defer deadlineTimer.Stop()
   151  				}
   152  				deadlineTimer.Reset(deadline)
   153  			}
   154  			if s.dataForWriting == nil || s.cancelWriteErr != nil || s.closeForShutdownErr != nil {
   155  				break
   156  			}
   157  		}
   158  
   159  		s.mutex.Unlock()
   160  		if !notifiedSender {
   161  			s.sender.onHasStreamData(s.streamID) // must be called without holding the mutex
   162  			notifiedSender = true
   163  		}
   164  		if copied {
   165  			s.mutex.Lock()
   166  			break
   167  		}
   168  		if deadline.IsZero() {
   169  			<-s.writeChan
   170  		} else {
   171  			select {
   172  			case <-s.writeChan:
   173  			case <-deadlineTimer.Chan():
   174  				deadlineTimer.SetRead()
   175  			}
   176  		}
   177  		s.mutex.Lock()
   178  	}
   179  
   180  	if bytesWritten == len(p) {
   181  		return bytesWritten, nil
   182  	}
   183  	if s.closeForShutdownErr != nil {
   184  		return bytesWritten, s.closeForShutdownErr
   185  	} else if s.cancelWriteErr != nil {
   186  		return bytesWritten, s.cancelWriteErr
   187  	}
   188  	return bytesWritten, nil
   189  }
   190  
   191  func (s *sendStream) canBufferStreamFrame() bool {
   192  	var l protocol.ByteCount
   193  	if s.nextFrame != nil {
   194  		l = s.nextFrame.DataLen()
   195  	}
   196  	return l+protocol.ByteCount(len(s.dataForWriting)) <= protocol.MaxPacketBufferSize
   197  }
   198  
   199  // popStreamFrame returns the next STREAM frame that is supposed to be sent on this stream
   200  // maxBytes is the maximum length this frame (including frame header) will have.
   201  func (s *sendStream) popStreamFrame(maxBytes protocol.ByteCount, v protocol.Version) (af ackhandler.StreamFrame, ok, hasMore bool) {
   202  	s.mutex.Lock()
   203  	f, hasMoreData := s.popNewOrRetransmittedStreamFrame(maxBytes, v)
   204  	if f != nil {
   205  		s.numOutstandingFrames++
   206  	}
   207  	s.mutex.Unlock()
   208  
   209  	if f == nil {
   210  		return ackhandler.StreamFrame{}, false, hasMoreData
   211  	}
   212  	return ackhandler.StreamFrame{
   213  		Frame:   f,
   214  		Handler: (*sendStreamAckHandler)(s),
   215  	}, true, hasMoreData
   216  }
   217  
   218  func (s *sendStream) popNewOrRetransmittedStreamFrame(maxBytes protocol.ByteCount, v protocol.Version) (*wire.StreamFrame, bool /* has more data to send */) {
   219  	if s.cancelWriteErr != nil || s.closeForShutdownErr != nil {
   220  		return nil, false
   221  	}
   222  
   223  	if len(s.retransmissionQueue) > 0 {
   224  		f, hasMoreRetransmissions := s.maybeGetRetransmission(maxBytes, v)
   225  		if f != nil || hasMoreRetransmissions {
   226  			if f == nil {
   227  				return nil, true
   228  			}
   229  			// We always claim that we have more data to send.
   230  			// This might be incorrect, in which case there'll be a spurious call to popStreamFrame in the future.
   231  			return f, true
   232  		}
   233  	}
   234  
   235  	if len(s.dataForWriting) == 0 && s.nextFrame == nil {
   236  		if s.finishedWriting && !s.finSent {
   237  			s.finSent = true
   238  			return &wire.StreamFrame{
   239  				StreamID:       s.streamID,
   240  				Offset:         s.writeOffset,
   241  				DataLenPresent: true,
   242  				Fin:            true,
   243  			}, false
   244  		}
   245  		return nil, false
   246  	}
   247  
   248  	sendWindow := s.flowController.SendWindowSize()
   249  	if sendWindow == 0 {
   250  		if isBlocked, offset := s.flowController.IsNewlyBlocked(); isBlocked {
   251  			s.sender.queueControlFrame(&wire.StreamDataBlockedFrame{
   252  				StreamID:          s.streamID,
   253  				MaximumStreamData: offset,
   254  			})
   255  			return nil, false
   256  		}
   257  		return nil, true
   258  	}
   259  
   260  	f, hasMoreData := s.popNewStreamFrame(maxBytes, sendWindow, v)
   261  	if dataLen := f.DataLen(); dataLen > 0 {
   262  		s.writeOffset += f.DataLen()
   263  		s.flowController.AddBytesSent(f.DataLen())
   264  	}
   265  	f.Fin = s.finishedWriting && s.dataForWriting == nil && s.nextFrame == nil && !s.finSent
   266  	if f.Fin {
   267  		s.finSent = true
   268  	}
   269  	return f, hasMoreData
   270  }
   271  
   272  func (s *sendStream) popNewStreamFrame(maxBytes, sendWindow protocol.ByteCount, v protocol.Version) (*wire.StreamFrame, bool) {
   273  	if s.nextFrame != nil {
   274  		nextFrame := s.nextFrame
   275  		s.nextFrame = nil
   276  
   277  		maxDataLen := min(sendWindow, nextFrame.MaxDataLen(maxBytes, v))
   278  		if nextFrame.DataLen() > maxDataLen {
   279  			s.nextFrame = wire.GetStreamFrame()
   280  			s.nextFrame.StreamID = s.streamID
   281  			s.nextFrame.Offset = s.writeOffset + maxDataLen
   282  			s.nextFrame.Data = s.nextFrame.Data[:nextFrame.DataLen()-maxDataLen]
   283  			s.nextFrame.DataLenPresent = true
   284  			copy(s.nextFrame.Data, nextFrame.Data[maxDataLen:])
   285  			nextFrame.Data = nextFrame.Data[:maxDataLen]
   286  		} else {
   287  			s.signalWrite()
   288  		}
   289  		return nextFrame, s.nextFrame != nil || s.dataForWriting != nil
   290  	}
   291  
   292  	f := wire.GetStreamFrame()
   293  	f.Fin = false
   294  	f.StreamID = s.streamID
   295  	f.Offset = s.writeOffset
   296  	f.DataLenPresent = true
   297  	f.Data = f.Data[:0]
   298  
   299  	hasMoreData := s.popNewStreamFrameWithoutBuffer(f, maxBytes, sendWindow, v)
   300  	if len(f.Data) == 0 && !f.Fin {
   301  		f.PutBack()
   302  		return nil, hasMoreData
   303  	}
   304  	return f, hasMoreData
   305  }
   306  
   307  func (s *sendStream) popNewStreamFrameWithoutBuffer(f *wire.StreamFrame, maxBytes, sendWindow protocol.ByteCount, v protocol.Version) bool {
   308  	maxDataLen := f.MaxDataLen(maxBytes, v)
   309  	if maxDataLen == 0 { // a STREAM frame must have at least one byte of data
   310  		return s.dataForWriting != nil || s.nextFrame != nil || s.finishedWriting
   311  	}
   312  	s.getDataForWriting(f, min(maxDataLen, sendWindow))
   313  
   314  	return s.dataForWriting != nil || s.nextFrame != nil || s.finishedWriting
   315  }
   316  
   317  func (s *sendStream) maybeGetRetransmission(maxBytes protocol.ByteCount, v protocol.Version) (*wire.StreamFrame, bool /* has more retransmissions */) {
   318  	f := s.retransmissionQueue[0]
   319  	newFrame, needsSplit := f.MaybeSplitOffFrame(maxBytes, v)
   320  	if needsSplit {
   321  		return newFrame, true
   322  	}
   323  	s.retransmissionQueue = s.retransmissionQueue[1:]
   324  	return f, len(s.retransmissionQueue) > 0
   325  }
   326  
   327  func (s *sendStream) hasData() bool {
   328  	s.mutex.Lock()
   329  	hasData := len(s.dataForWriting) > 0
   330  	s.mutex.Unlock()
   331  	return hasData
   332  }
   333  
   334  func (s *sendStream) getDataForWriting(f *wire.StreamFrame, maxBytes protocol.ByteCount) {
   335  	if protocol.ByteCount(len(s.dataForWriting)) <= maxBytes {
   336  		f.Data = f.Data[:len(s.dataForWriting)]
   337  		copy(f.Data, s.dataForWriting)
   338  		s.dataForWriting = nil
   339  		s.signalWrite()
   340  		return
   341  	}
   342  	f.Data = f.Data[:maxBytes]
   343  	copy(f.Data, s.dataForWriting)
   344  	s.dataForWriting = s.dataForWriting[maxBytes:]
   345  	if s.canBufferStreamFrame() {
   346  		s.signalWrite()
   347  	}
   348  }
   349  
   350  func (s *sendStream) isNewlyCompleted() bool {
   351  	completed := (s.finSent || s.cancelWriteErr != nil) && s.numOutstandingFrames == 0 && len(s.retransmissionQueue) == 0
   352  	if completed && !s.completed {
   353  		s.completed = true
   354  		return true
   355  	}
   356  	return false
   357  }
   358  
   359  func (s *sendStream) Close() error {
   360  	s.mutex.Lock()
   361  	if s.closeForShutdownErr != nil {
   362  		s.mutex.Unlock()
   363  		return nil
   364  	}
   365  	if s.cancelWriteErr != nil {
   366  		s.mutex.Unlock()
   367  		return fmt.Errorf("close called for canceled stream %d", s.streamID)
   368  	}
   369  	s.ctxCancel(nil)
   370  	s.finishedWriting = true
   371  	s.mutex.Unlock()
   372  
   373  	s.sender.onHasStreamData(s.streamID) // need to send the FIN, must be called without holding the mutex
   374  	return nil
   375  }
   376  
   377  func (s *sendStream) CancelWrite(errorCode StreamErrorCode) {
   378  	s.cancelWriteImpl(errorCode, false)
   379  }
   380  
   381  // must be called after locking the mutex
   382  func (s *sendStream) cancelWriteImpl(errorCode qerr.StreamErrorCode, remote bool) {
   383  	s.mutex.Lock()
   384  	if s.cancelWriteErr != nil {
   385  		s.mutex.Unlock()
   386  		return
   387  	}
   388  	s.cancelWriteErr = &StreamError{StreamID: s.streamID, ErrorCode: errorCode, Remote: remote}
   389  	s.ctxCancel(s.cancelWriteErr)
   390  	s.numOutstandingFrames = 0
   391  	s.retransmissionQueue = nil
   392  	newlyCompleted := s.isNewlyCompleted()
   393  	s.mutex.Unlock()
   394  
   395  	s.signalWrite()
   396  	s.sender.queueControlFrame(&wire.ResetStreamFrame{
   397  		StreamID:  s.streamID,
   398  		FinalSize: s.writeOffset,
   399  		ErrorCode: errorCode,
   400  	})
   401  	if newlyCompleted {
   402  		s.sender.onStreamCompleted(s.streamID)
   403  	}
   404  }
   405  
   406  func (s *sendStream) updateSendWindow(limit protocol.ByteCount) {
   407  	updated := s.flowController.UpdateSendWindow(limit)
   408  	if !updated { // duplicate or reordered MAX_STREAM_DATA frame
   409  		return
   410  	}
   411  	s.mutex.Lock()
   412  	hasStreamData := s.dataForWriting != nil || s.nextFrame != nil
   413  	s.mutex.Unlock()
   414  	if hasStreamData {
   415  		s.sender.onHasStreamData(s.streamID)
   416  	}
   417  }
   418  
   419  func (s *sendStream) handleStopSendingFrame(frame *wire.StopSendingFrame) {
   420  	s.cancelWriteImpl(frame.ErrorCode, true)
   421  }
   422  
   423  func (s *sendStream) Context() context.Context {
   424  	return s.ctx
   425  }
   426  
   427  func (s *sendStream) SetWriteDeadline(t time.Time) error {
   428  	s.mutex.Lock()
   429  	s.deadline = t
   430  	s.mutex.Unlock()
   431  	s.signalWrite()
   432  	return nil
   433  }
   434  
   435  // CloseForShutdown closes a stream abruptly.
   436  // It makes Write unblock (and return the error) immediately.
   437  // The peer will NOT be informed about this: the stream is closed without sending a FIN or RST.
   438  func (s *sendStream) closeForShutdown(err error) {
   439  	s.mutex.Lock()
   440  	s.ctxCancel(err)
   441  	s.closeForShutdownErr = err
   442  	s.mutex.Unlock()
   443  	s.signalWrite()
   444  }
   445  
   446  // signalWrite performs a non-blocking send on the writeChan
   447  func (s *sendStream) signalWrite() {
   448  	select {
   449  	case s.writeChan <- struct{}{}:
   450  	default:
   451  	}
   452  }
   453  
   454  type sendStreamAckHandler sendStream
   455  
   456  var _ ackhandler.FrameHandler = &sendStreamAckHandler{}
   457  
   458  func (s *sendStreamAckHandler) OnAcked(f wire.Frame) {
   459  	sf := f.(*wire.StreamFrame)
   460  	sf.PutBack()
   461  	s.mutex.Lock()
   462  	if s.cancelWriteErr != nil {
   463  		s.mutex.Unlock()
   464  		return
   465  	}
   466  	s.numOutstandingFrames--
   467  	if s.numOutstandingFrames < 0 {
   468  		panic("numOutStandingFrames negative")
   469  	}
   470  	newlyCompleted := (*sendStream)(s).isNewlyCompleted()
   471  	s.mutex.Unlock()
   472  
   473  	if newlyCompleted {
   474  		s.sender.onStreamCompleted(s.streamID)
   475  	}
   476  }
   477  
   478  func (s *sendStreamAckHandler) OnLost(f wire.Frame) {
   479  	sf := f.(*wire.StreamFrame)
   480  	s.mutex.Lock()
   481  	if s.cancelWriteErr != nil {
   482  		s.mutex.Unlock()
   483  		return
   484  	}
   485  	sf.DataLenPresent = true
   486  	s.retransmissionQueue = append(s.retransmissionQueue, sf)
   487  	s.numOutstandingFrames--
   488  	if s.numOutstandingFrames < 0 {
   489  		panic("numOutStandingFrames negative")
   490  	}
   491  	s.mutex.Unlock()
   492  
   493  	s.sender.onHasStreamData(s.streamID)
   494  }