github.com/danielpfeifer02/quic-go-prio-packs@v0.41.0-28/send_stream.go (about)

     1  package quic
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"sync"
     7  	"time"
     8  
     9  	"github.com/danielpfeifer02/quic-go-prio-packs/internal/ackhandler"
    10  	"github.com/danielpfeifer02/quic-go-prio-packs/internal/flowcontrol"
    11  	"github.com/danielpfeifer02/quic-go-prio-packs/internal/protocol"
    12  	"github.com/danielpfeifer02/quic-go-prio-packs/internal/qerr"
    13  	"github.com/danielpfeifer02/quic-go-prio-packs/internal/utils"
    14  	"github.com/danielpfeifer02/quic-go-prio-packs/internal/wire"
    15  	"github.com/danielpfeifer02/quic-go-prio-packs/priority_setting"
    16  )
    17  
    18  type sendStreamI interface {
    19  	SendStream
    20  	handleStopSendingFrame(*wire.StopSendingFrame)
    21  	hasData() bool
    22  	popStreamFrame(maxBytes protocol.ByteCount, v protocol.Version) (frame ackhandler.StreamFrame, ok, hasMore bool)
    23  	closeForShutdown(error)
    24  	updateSendWindow(protocol.ByteCount)
    25  }
    26  
// sendStream is the send side of a QUIC stream.
// Most of its mutable state is guarded by mutex.
type sendStream struct {
	mutex sync.Mutex

	// numOutstandingFrames counts STREAM frames handed out by popStreamFrame that
	// were neither acknowledged nor declared lost yet. retransmissionQueue holds
	// lost frames; they are sent before any new data.
	numOutstandingFrames int64
	retransmissionQueue  []*wire.StreamFrame

	// ctx is canceled by Close, CancelWrite / STOP_SENDING and closeForShutdown.
	ctx       context.Context
	ctxCancel context.CancelCauseFunc

	streamID protocol.StreamID
	sender   streamSender

	// writeOffset is the stream offset at which the next new STREAM frame starts.
	writeOffset protocol.ByteCount

	// cancelWriteErr is set when the stream was reset (locally or due to
	// STOP_SENDING); closeForShutdownErr is set by closeForShutdown.
	cancelWriteErr      error
	closeForShutdownErr error

	finishedWriting bool // set once Close() is called
	finSent         bool // set when a STREAM_FRAME with FIN bit has been sent
	completed       bool // set when this stream has been reported to the streamSender as completed

	dataForWriting []byte            // during a Write() call, this slice is the part of p that still needs to be sent out
	nextFrame      *wire.StreamFrame // data already copied out of p; popped before dataForWriting

	// writeChan wakes up a blocked Write (see signalWrite), writeOnce serializes
	// concurrent Write calls, deadline is the write deadline (zero means none).
	writeChan chan struct{}
	writeOnce chan struct{}
	deadline  time.Time

	flowController flowcontrol.StreamFlowController

	// PRIO_PACKS_TAG
	priority priority_setting.Priority
}
    60  
    61  var (
    62  	_ SendStream  = &sendStream{}
    63  	_ sendStreamI = &sendStream{}
    64  )
    65  
    66  func newSendStream(
    67  	streamID protocol.StreamID,
    68  	sender streamSender,
    69  	flowController flowcontrol.StreamFlowController,
    70  ) *sendStream {
    71  	s := &sendStream{
    72  		streamID:       streamID,
    73  		sender:         sender,
    74  		flowController: flowController,
    75  		writeChan:      make(chan struct{}, 1),
    76  		writeOnce:      make(chan struct{}, 1), // cap: 1, to protect against concurrent use of Write
    77  		// PRIO_PACKS_TAG
    78  		priority: priority_setting.NoPriority,
    79  	}
    80  	s.ctx, s.ctxCancel = context.WithCancelCause(context.Background())
    81  	return s
    82  }
    83  
    84  func (s *sendStream) StreamID() protocol.StreamID {
    85  	return s.streamID // same for receiveStream and sendStream
    86  }
    87  
// Write queues p for sending and blocks until all of p has been handed off:
// either copied into STREAM frames by popStreamFrame, or copied into the
// buffered nextFrame. It returns early with the number of bytes handed off so
// far if the write deadline passes (errDeadline), the stream is canceled
// (cancelWriteErr) or closed for shutdown (closeForShutdownErr).
// Writing after Close returns an error.
func (s *sendStream) Write(p []byte) (int, error) {
	// Concurrent use of Write is not permitted (and doesn't make any sense),
	// but sometimes people do it anyway.
	// Make sure that we only execute one call at any given time to avoid hard to debug failures.
	s.writeOnce <- struct{}{}
	defer func() { <-s.writeOnce }()

	// Every return path below must hold the mutex, because of this deferred Unlock.
	s.mutex.Lock()
	defer s.mutex.Unlock()

	if s.finishedWriting {
		return 0, fmt.Errorf("write on closed stream %d", s.streamID)
	}
	if s.cancelWriteErr != nil {
		return 0, s.cancelWriteErr
	}
	if s.closeForShutdownErr != nil {
		return 0, s.closeForShutdownErr
	}
	if !s.deadline.IsZero() && !time.Now().Before(s.deadline) {
		return 0, errDeadline
	}
	if len(p) == 0 {
		return 0, nil
	}

	// Expose p to popStreamFrame, which copies from it into STREAM frames while
	// this call is blocked below.
	s.dataForWriting = p

	var (
		deadlineTimer  *utils.Timer
		bytesWritten   int
		notifiedSender bool
	)
	for {
		var copied bool
		var deadline time.Time
		// As soon as dataForWriting becomes smaller than a certain size x, we copy all the data to a STREAM frame (s.nextFrame),
		// which can then be popped the next time we assemble a packet.
		// This allows us to return Write() when all data but x bytes have been sent out.
		// When the user now calls Close(), this is much more likely to happen before we popped that last STREAM frame,
		// allowing us to set the FIN bit on that frame (instead of sending an empty STREAM frame with FIN).
		if s.canBufferStreamFrame() && len(s.dataForWriting) > 0 {
			if s.nextFrame == nil {
				f := wire.GetStreamFrame()
				f.Offset = s.writeOffset
				f.StreamID = s.streamID
				f.DataLenPresent = true
				f.Data = f.Data[:len(s.dataForWriting)]
				copy(f.Data, s.dataForWriting)
				s.nextFrame = f
			} else {
				// Append the remaining data to the frame that is already buffered.
				l := len(s.nextFrame.Data)
				s.nextFrame.Data = s.nextFrame.Data[:l+len(s.dataForWriting)]
				copy(s.nextFrame.Data[l:], s.dataForWriting)
			}
			s.dataForWriting = nil
			bytesWritten = len(p)
			copied = true
		} else {
			// Whatever popStreamFrame has already consumed from p counts as written.
			bytesWritten = len(p) - len(s.dataForWriting)
			deadline = s.deadline
			if !deadline.IsZero() {
				if !time.Now().Before(deadline) {
					s.dataForWriting = nil
					return bytesWritten, errDeadline
				}
				// The timer is created at most once per call, so this defer inside
				// the loop only ever registers a single Stop.
				if deadlineTimer == nil {
					deadlineTimer = utils.NewTimer()
					defer deadlineTimer.Stop()
				}
				deadlineTimer.Reset(deadline)
			}
			if s.dataForWriting == nil || s.cancelWriteErr != nil || s.closeForShutdownErr != nil {
				break
			}
		}

		s.mutex.Unlock()
		if !notifiedSender {
			s.sender.onHasStreamData(s.streamID) // must be called without holding the mutex
			notifiedSender = true
		}
		if copied {
			// Re-acquire the mutex for the deferred Unlock before leaving the loop.
			s.mutex.Lock()
			break
		}
		// Sleep until signalWrite wakes us up (data consumed, deadline changed,
		// stream canceled or shut down) or the deadline timer fires.
		if deadline.IsZero() {
			<-s.writeChan
		} else {
			select {
			case <-s.writeChan:
			case <-deadlineTimer.Chan():
				deadlineTimer.SetRead()
			}
		}
		s.mutex.Lock()
	}

	if bytesWritten == len(p) {
		return bytesWritten, nil
	}
	if s.closeForShutdownErr != nil {
		return bytesWritten, s.closeForShutdownErr
	} else if s.cancelWriteErr != nil {
		return bytesWritten, s.cancelWriteErr
	}
	return bytesWritten, nil
}
   196  
   197  func (s *sendStream) canBufferStreamFrame() bool {
   198  	var l protocol.ByteCount
   199  	if s.nextFrame != nil {
   200  		l = s.nextFrame.DataLen()
   201  	}
   202  	return l+protocol.ByteCount(len(s.dataForWriting)) <= protocol.MaxPacketBufferSize
   203  }
   204  
   205  // popStreamFrame returns the next STREAM frame that is supposed to be sent on this stream
   206  // maxBytes is the maximum length this frame (including frame header) will have.
   207  func (s *sendStream) popStreamFrame(maxBytes protocol.ByteCount, v protocol.Version) (af ackhandler.StreamFrame, ok, hasMore bool) {
   208  	s.mutex.Lock()
   209  	f, hasMoreData := s.popNewOrRetransmittedStreamFrame(maxBytes, v)
   210  	if f != nil {
   211  		s.numOutstandingFrames++
   212  	}
   213  	s.mutex.Unlock()
   214  
   215  	if f == nil {
   216  		return ackhandler.StreamFrame{}, false, hasMoreData
   217  	}
   218  	return ackhandler.StreamFrame{
   219  		Frame:   f,
   220  		Handler: (*sendStreamAckHandler)(s),
   221  	}, true, hasMoreData
   222  }
   223  
   224  func (s *sendStream) popNewOrRetransmittedStreamFrame(maxBytes protocol.ByteCount, v protocol.Version) (*wire.StreamFrame, bool /* has more data to send */) {
   225  	if s.cancelWriteErr != nil || s.closeForShutdownErr != nil {
   226  		return nil, false
   227  	}
   228  
   229  	if len(s.retransmissionQueue) > 0 {
   230  		f, hasMoreRetransmissions := s.maybeGetRetransmission(maxBytes, v)
   231  		if f != nil || hasMoreRetransmissions {
   232  			if f == nil {
   233  				return nil, true
   234  			}
   235  			// We always claim that we have more data to send.
   236  			// This might be incorrect, in which case there'll be a spurious call to popStreamFrame in the future.
   237  			return f, true
   238  		}
   239  	}
   240  
   241  	if len(s.dataForWriting) == 0 && s.nextFrame == nil {
   242  		if s.finishedWriting && !s.finSent {
   243  			s.finSent = true
   244  			return &wire.StreamFrame{
   245  				StreamID:       s.streamID,
   246  				Offset:         s.writeOffset,
   247  				DataLenPresent: true,
   248  				Fin:            true,
   249  			}, false
   250  		}
   251  		return nil, false
   252  	}
   253  
   254  	sendWindow := s.flowController.SendWindowSize()
   255  	if sendWindow == 0 {
   256  		if isBlocked, offset := s.flowController.IsNewlyBlocked(); isBlocked {
   257  			s.sender.queueControlFrame(&wire.StreamDataBlockedFrame{
   258  				StreamID:          s.streamID,
   259  				MaximumStreamData: offset,
   260  			})
   261  			return nil, false
   262  		}
   263  		return nil, true
   264  	}
   265  
   266  	f, hasMoreData := s.popNewStreamFrame(maxBytes, sendWindow, v)
   267  	if dataLen := f.DataLen(); dataLen > 0 {
   268  		s.writeOffset += f.DataLen()
   269  		s.flowController.AddBytesSent(f.DataLen())
   270  	}
   271  	f.Fin = s.finishedWriting && s.dataForWriting == nil && s.nextFrame == nil && !s.finSent
   272  	if f.Fin {
   273  		s.finSent = true
   274  	}
   275  	return f, hasMoreData
   276  }
   277  
   278  func (s *sendStream) popNewStreamFrame(maxBytes, sendWindow protocol.ByteCount, v protocol.Version) (*wire.StreamFrame, bool) {
   279  	if s.nextFrame != nil {
   280  		nextFrame := s.nextFrame
   281  		s.nextFrame = nil
   282  
   283  		maxDataLen := min(sendWindow, nextFrame.MaxDataLen(maxBytes, v))
   284  		if nextFrame.DataLen() > maxDataLen {
   285  			s.nextFrame = wire.GetStreamFrame()
   286  			s.nextFrame.StreamID = s.streamID
   287  			s.nextFrame.Offset = s.writeOffset + maxDataLen
   288  			s.nextFrame.Data = s.nextFrame.Data[:nextFrame.DataLen()-maxDataLen]
   289  			s.nextFrame.DataLenPresent = true
   290  			copy(s.nextFrame.Data, nextFrame.Data[maxDataLen:])
   291  			nextFrame.Data = nextFrame.Data[:maxDataLen]
   292  		} else {
   293  			s.signalWrite()
   294  		}
   295  		return nextFrame, s.nextFrame != nil || s.dataForWriting != nil
   296  	}
   297  
   298  	f := wire.GetStreamFrame()
   299  	f.Fin = false
   300  	f.StreamID = s.streamID
   301  	f.Offset = s.writeOffset
   302  	f.DataLenPresent = true
   303  	f.Data = f.Data[:0]
   304  
   305  	hasMoreData := s.popNewStreamFrameWithoutBuffer(f, maxBytes, sendWindow, v)
   306  	if len(f.Data) == 0 && !f.Fin {
   307  		f.PutBack()
   308  		return nil, hasMoreData
   309  	}
   310  	return f, hasMoreData
   311  }
   312  
   313  func (s *sendStream) popNewStreamFrameWithoutBuffer(f *wire.StreamFrame, maxBytes, sendWindow protocol.ByteCount, v protocol.Version) bool {
   314  	maxDataLen := f.MaxDataLen(maxBytes, v)
   315  	if maxDataLen == 0 { // a STREAM frame must have at least one byte of data
   316  		return s.dataForWriting != nil || s.nextFrame != nil || s.finishedWriting
   317  	}
   318  	s.getDataForWriting(f, min(maxDataLen, sendWindow))
   319  
   320  	return s.dataForWriting != nil || s.nextFrame != nil || s.finishedWriting
   321  }
   322  
   323  func (s *sendStream) maybeGetRetransmission(maxBytes protocol.ByteCount, v protocol.Version) (*wire.StreamFrame, bool /* has more retransmissions */) {
   324  	f := s.retransmissionQueue[0]
   325  	newFrame, needsSplit := f.MaybeSplitOffFrame(maxBytes, v)
   326  	if needsSplit {
   327  		return newFrame, true
   328  	}
   329  	s.retransmissionQueue = s.retransmissionQueue[1:]
   330  	return f, len(s.retransmissionQueue) > 0
   331  }
   332  
   333  func (s *sendStream) hasData() bool {
   334  	s.mutex.Lock()
   335  	hasData := len(s.dataForWriting) > 0
   336  	s.mutex.Unlock()
   337  	return hasData
   338  }
   339  
   340  func (s *sendStream) getDataForWriting(f *wire.StreamFrame, maxBytes protocol.ByteCount) {
   341  	if protocol.ByteCount(len(s.dataForWriting)) <= maxBytes {
   342  		f.Data = f.Data[:len(s.dataForWriting)]
   343  		copy(f.Data, s.dataForWriting)
   344  		s.dataForWriting = nil
   345  		s.signalWrite()
   346  		return
   347  	}
   348  	f.Data = f.Data[:maxBytes]
   349  	copy(f.Data, s.dataForWriting)
   350  	s.dataForWriting = s.dataForWriting[maxBytes:]
   351  	if s.canBufferStreamFrame() {
   352  		s.signalWrite()
   353  	}
   354  }
   355  
   356  func (s *sendStream) isNewlyCompleted() bool {
   357  	completed := (s.finSent || s.cancelWriteErr != nil) && s.numOutstandingFrames == 0 && len(s.retransmissionQueue) == 0
   358  	if completed && !s.completed {
   359  		s.completed = true
   360  		return true
   361  	}
   362  	return false
   363  }
   364  
   365  func (s *sendStream) Close() error {
   366  	s.mutex.Lock()
   367  	if s.closeForShutdownErr != nil {
   368  		s.mutex.Unlock()
   369  		return nil
   370  	}
   371  	if s.cancelWriteErr != nil {
   372  		s.mutex.Unlock()
   373  		return fmt.Errorf("close called for canceled stream %d", s.streamID)
   374  	}
   375  	s.ctxCancel(nil)
   376  	s.finishedWriting = true
   377  	s.mutex.Unlock()
   378  
   379  	s.sender.onHasStreamData(s.streamID) // need to send the FIN, must be called without holding the mutex
   380  	return nil
   381  }
   382  
   383  func (s *sendStream) CancelWrite(errorCode StreamErrorCode) {
   384  	s.cancelWriteImpl(errorCode, false)
   385  }
   386  
   387  // must be called after locking the mutex
   388  func (s *sendStream) cancelWriteImpl(errorCode qerr.StreamErrorCode, remote bool) {
   389  	s.mutex.Lock()
   390  	if s.cancelWriteErr != nil {
   391  		s.mutex.Unlock()
   392  		return
   393  	}
   394  	s.cancelWriteErr = &StreamError{StreamID: s.streamID, ErrorCode: errorCode, Remote: remote}
   395  	s.ctxCancel(s.cancelWriteErr)
   396  	s.numOutstandingFrames = 0
   397  	s.retransmissionQueue = nil
   398  	newlyCompleted := s.isNewlyCompleted()
   399  	s.mutex.Unlock()
   400  
   401  	s.signalWrite()
   402  	s.sender.queueControlFrame(&wire.ResetStreamFrame{
   403  		StreamID:  s.streamID,
   404  		FinalSize: s.writeOffset,
   405  		ErrorCode: errorCode,
   406  	})
   407  	if newlyCompleted {
   408  		s.sender.onStreamCompleted(s.streamID)
   409  	}
   410  }
   411  
   412  func (s *sendStream) updateSendWindow(limit protocol.ByteCount) {
   413  	updated := s.flowController.UpdateSendWindow(limit)
   414  	if !updated { // duplicate or reordered MAX_STREAM_DATA frame
   415  		return
   416  	}
   417  	s.mutex.Lock()
   418  	hasStreamData := s.dataForWriting != nil || s.nextFrame != nil
   419  	s.mutex.Unlock()
   420  	if hasStreamData {
   421  		s.sender.onHasStreamData(s.streamID)
   422  	}
   423  }
   424  
   425  func (s *sendStream) handleStopSendingFrame(frame *wire.StopSendingFrame) {
   426  	s.cancelWriteImpl(frame.ErrorCode, true)
   427  }
   428  
   429  func (s *sendStream) Context() context.Context {
   430  	return s.ctx
   431  }
   432  
   433  func (s *sendStream) SetWriteDeadline(t time.Time) error {
   434  	s.mutex.Lock()
   435  	s.deadline = t
   436  	s.mutex.Unlock()
   437  	s.signalWrite()
   438  	return nil
   439  }
   440  
   441  // CloseForShutdown closes a stream abruptly.
   442  // It makes Write unblock (and return the error) immediately.
   443  // The peer will NOT be informed about this: the stream is closed without sending a FIN or RST.
   444  func (s *sendStream) closeForShutdown(err error) {
   445  	s.mutex.Lock()
   446  	s.ctxCancel(err)
   447  	s.closeForShutdownErr = err
   448  	s.mutex.Unlock()
   449  	s.signalWrite()
   450  }
   451  
   452  // signalWrite performs a non-blocking send on the writeChan
   453  func (s *sendStream) signalWrite() {
   454  	select {
   455  	case s.writeChan <- struct{}{}:
   456  	default:
   457  	}
   458  }
   459  
   460  // PRIO_PACKS_TAG
   461  // Priority returns the priority of the stream
   462  func (s *sendStream) Priority() priority_setting.Priority {
   463  	return s.priority
   464  }
   465  
   466  // PRIO_PACKS_TAG
   467  // SetPriority sets the priority of the stream
   468  func (s *sendStream) SetPriority(priority priority_setting.Priority) {
   469  	s.priority = priority
   470  }
   471  
   472  type sendStreamAckHandler sendStream
   473  
   474  var _ ackhandler.FrameHandler = &sendStreamAckHandler{}
   475  
   476  func (s *sendStreamAckHandler) OnAcked(f wire.Frame) {
   477  	sf := f.(*wire.StreamFrame)
   478  	sf.PutBack()
   479  	s.mutex.Lock()
   480  	if s.cancelWriteErr != nil {
   481  		s.mutex.Unlock()
   482  		return
   483  	}
   484  	s.numOutstandingFrames--
   485  	if s.numOutstandingFrames < 0 {
   486  		panic("numOutStandingFrames negative")
   487  	}
   488  	newlyCompleted := (*sendStream)(s).isNewlyCompleted()
   489  	s.mutex.Unlock()
   490  
   491  	if newlyCompleted {
   492  		s.sender.onStreamCompleted(s.streamID)
   493  	}
   494  }
   495  
   496  func (s *sendStreamAckHandler) OnLost(f wire.Frame) {
   497  	sf := f.(*wire.StreamFrame)
   498  	s.mutex.Lock()
   499  	if s.cancelWriteErr != nil {
   500  		s.mutex.Unlock()
   501  		return
   502  	}
   503  	sf.DataLenPresent = true
   504  	s.retransmissionQueue = append(s.retransmissionQueue, sf)
   505  	s.numOutstandingFrames--
   506  	if s.numOutstandingFrames < 0 {
   507  		panic("numOutStandingFrames negative")
   508  	}
   509  	s.mutex.Unlock()
   510  
   511  	s.sender.onHasStreamData(s.streamID)
   512  }