github.com/tumi8/quic-go@v0.37.4-tum/send_stream.go (about)

     1  package quic
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"sync"
     7  	"time"
     8  
     9  	"github.com/tumi8/quic-go/noninternal/ackhandler"
    10  
    11  	"github.com/tumi8/quic-go/noninternal/flowcontrol"
    12  	"github.com/tumi8/quic-go/noninternal/protocol"
    13  	"github.com/tumi8/quic-go/noninternal/qerr"
    14  	"github.com/tumi8/quic-go/noninternal/utils"
    15  	"github.com/tumi8/quic-go/noninternal/wire"
    16  )
    17  
// sendStreamI is the package-internal view of a send stream: the exported
// SendStream API plus the hooks listed below.
type sendStreamI interface {
	SendStream
	// handleStopSendingFrame processes a STOP_SENDING frame for this stream.
	handleStopSendingFrame(*wire.StopSendingFrame)
	// hasData reports whether the stream currently has data waiting to be sent.
	hasData() bool
	// popStreamFrame returns the next STREAM frame to send, limited to maxBytes.
	// ok is false if no frame was produced; hasMore reports whether the stream
	// should be asked again.
	popStreamFrame(maxBytes protocol.ByteCount, v protocol.VersionNumber) (frame ackhandler.StreamFrame, ok, hasMore bool)
	// closeForShutdown closes the stream with the given error.
	closeForShutdown(error)
	// updateSendWindow passes a new flow control limit to the stream.
	updateSendWindow(protocol.ByteCount)
}
    26  
// sendStream is the sending half of a QUIC stream.
// All mutable fields are accessed while holding mutex.
type sendStream struct {
	mutex sync.Mutex

	// numOutstandingFrames counts frames handed out by popStreamFrame that have
	// been neither acknowledged nor declared lost yet.
	// retransmissionQueue holds frames that were declared lost and need to be sent again.
	numOutstandingFrames int64
	retransmissionQueue  []*wire.StreamFrame

	// ctx is canceled (via ctxCancel) when the stream is closed, canceled or shut down.
	ctx       context.Context
	ctxCancel context.CancelCauseFunc

	streamID protocol.StreamID
	sender   streamSender

	// writeOffset is the stream offset of the next new byte to be sent.
	writeOffset protocol.ByteCount

	// cancelWriteErr is set once the write side was canceled (locally or by the peer).
	// closeForShutdownErr is set once closeForShutdown was called.
	cancelWriteErr      error
	closeForShutdownErr error

	finishedWriting bool // set once Close() is called
	finSent         bool // set when a STREAM_FRAME with FIN bit has been sent
	completed       bool // set when this stream has been reported to the streamSender as completed

	dataForWriting []byte // during a Write() call, this slice is the part of p that still needs to be sent out
	nextFrame      *wire.StreamFrame

	// writeChan wakes up a blocked Write call (see signalWrite).
	// writeOnce serializes concurrent Write calls.
	// deadline is the write deadline; the zero value means no deadline.
	writeChan chan struct{}
	writeOnce chan struct{}
	deadline  time.Time

	flowController flowcontrol.StreamFlowController
}
    57  
    58  var (
    59  	_ SendStream  = &sendStream{}
    60  	_ sendStreamI = &sendStream{}
    61  )
    62  
    63  func newSendStream(
    64  	streamID protocol.StreamID,
    65  	sender streamSender,
    66  	flowController flowcontrol.StreamFlowController,
    67  ) *sendStream {
    68  	s := &sendStream{
    69  		streamID:       streamID,
    70  		sender:         sender,
    71  		flowController: flowController,
    72  		writeChan:      make(chan struct{}, 1),
    73  		writeOnce:      make(chan struct{}, 1), // cap: 1, to protect against concurrent use of Write
    74  	}
    75  	s.ctx, s.ctxCancel = context.WithCancelCause(context.Background())
    76  	return s
    77  }
    78  
// StreamID returns the ID of this stream.
func (s *sendStream) StreamID() protocol.StreamID {
	return s.streamID // same for receiveStream and sendStream
}
    82  
// Write queues p for sending on the stream.
//
// It blocks until every byte of p has either been popped into a STREAM frame
// or copied into the buffered next frame, or until the stream is canceled,
// shut down, or the write deadline expires. It returns the number of bytes of
// p that were consumed, together with the error that stopped the write, if any.
// Writing after Close returns an error.
func (s *sendStream) Write(p []byte) (int, error) {
	// Concurrent use of Write is not permitted (and doesn't make any sense),
	// but sometimes people do it anyway.
	// Make sure that we only execute one call at any given time to avoid hard to debug failures.
	s.writeOnce <- struct{}{}
	defer func() { <-s.writeOnce }()

	s.mutex.Lock()
	defer s.mutex.Unlock()

	if s.finishedWriting {
		return 0, fmt.Errorf("write on closed stream %d", s.streamID)
	}
	if s.cancelWriteErr != nil {
		return 0, s.cancelWriteErr
	}
	if s.closeForShutdownErr != nil {
		return 0, s.closeForShutdownErr
	}
	if !s.deadline.IsZero() && !time.Now().Before(s.deadline) {
		return 0, errDeadline
	}
	if len(p) == 0 {
		return 0, nil
	}

	s.dataForWriting = p

	var (
		deadlineTimer  *utils.Timer
		bytesWritten   int
		notifiedSender bool
	)
	// Every iteration starts and ends with the mutex held; it is only released
	// while notifying the sender and while waiting for a wake-up.
	for {
		var copied bool
		var deadline time.Time
		// As soon as dataForWriting becomes smaller than a certain size x, we copy all the data to a STREAM frame (s.nextFrame),
		// which can then be popped the next time we assemble a packet.
		// This allows us to return Write() when all data but x bytes have been sent out.
		// When the user now calls Close(), this is much more likely to happen before we popped that last STREAM frame,
		// allowing us to set the FIN bit on that frame (instead of sending an empty STREAM frame with FIN).
		if s.canBufferStreamFrame() && len(s.dataForWriting) > 0 {
			if s.nextFrame == nil {
				f := wire.GetStreamFrame()
				f.Offset = s.writeOffset
				f.StreamID = s.streamID
				f.DataLenPresent = true
				f.Data = f.Data[:len(s.dataForWriting)]
				copy(f.Data, s.dataForWriting)
				s.nextFrame = f
			} else {
				// Append to the frame that is already buffered.
				l := len(s.nextFrame.Data)
				s.nextFrame.Data = s.nextFrame.Data[:l+len(s.dataForWriting)]
				copy(s.nextFrame.Data[l:], s.dataForWriting)
			}
			s.dataForWriting = nil
			bytesWritten = len(p)
			copied = true
		} else {
			// Whatever is no longer in dataForWriting has been popped into frames.
			bytesWritten = len(p) - len(s.dataForWriting)
			deadline = s.deadline
			if !deadline.IsZero() {
				if !time.Now().Before(deadline) {
					s.dataForWriting = nil
					return bytesWritten, errDeadline
				}
				if deadlineTimer == nil {
					deadlineTimer = utils.NewTimer()
					defer deadlineTimer.Stop()
				}
				deadlineTimer.Reset(deadline)
			}
			// Stop waiting once all data was consumed or the stream was torn down.
			if s.dataForWriting == nil || s.cancelWriteErr != nil || s.closeForShutdownErr != nil {
				break
			}
		}

		s.mutex.Unlock()
		if !notifiedSender {
			s.sender.onHasStreamData(s.streamID) // must be called without holding the mutex
			notifiedSender = true
		}
		if copied {
			// Re-acquire the mutex so the deferred Unlock is balanced.
			s.mutex.Lock()
			break
		}
		// Wait until signalWrite wakes us up or the deadline timer fires.
		if deadline.IsZero() {
			<-s.writeChan
		} else {
			select {
			case <-s.writeChan:
			case <-deadlineTimer.Chan():
				deadlineTimer.SetRead()
			}
		}
		s.mutex.Lock()
	}

	if bytesWritten == len(p) {
		return bytesWritten, nil
	}
	if s.closeForShutdownErr != nil {
		return bytesWritten, s.closeForShutdownErr
	} else if s.cancelWriteErr != nil {
		return bytesWritten, s.cancelWriteErr
	}
	return bytesWritten, nil
}
   191  
   192  func (s *sendStream) canBufferStreamFrame() bool {
   193  	var l protocol.ByteCount
   194  	if s.nextFrame != nil {
   195  		l = s.nextFrame.DataLen()
   196  	}
   197  	return l+protocol.ByteCount(len(s.dataForWriting)) <= protocol.MaxPacketBufferSize
   198  }
   199  
   200  // popStreamFrame returns the next STREAM frame that is supposed to be sent on this stream
   201  // maxBytes is the maximum length this frame (including frame header) will have.
   202  func (s *sendStream) popStreamFrame(maxBytes protocol.ByteCount, v protocol.VersionNumber) (af ackhandler.StreamFrame, ok, hasMore bool) {
   203  	s.mutex.Lock()
   204  	f, hasMoreData := s.popNewOrRetransmittedStreamFrame(maxBytes, v)
   205  	if f != nil {
   206  		s.numOutstandingFrames++
   207  	}
   208  	s.mutex.Unlock()
   209  
   210  	if f == nil {
   211  		return ackhandler.StreamFrame{}, false, hasMoreData
   212  	}
   213  	return ackhandler.StreamFrame{
   214  		Frame:   f,
   215  		Handler: (*sendStreamAckHandler)(s),
   216  	}, true, hasMoreData
   217  }
   218  
   219  func (s *sendStream) popNewOrRetransmittedStreamFrame(maxBytes protocol.ByteCount, v protocol.VersionNumber) (*wire.StreamFrame, bool /* has more data to send */) {
   220  	if s.cancelWriteErr != nil || s.closeForShutdownErr != nil {
   221  		return nil, false
   222  	}
   223  
   224  	if len(s.retransmissionQueue) > 0 {
   225  		f, hasMoreRetransmissions := s.maybeGetRetransmission(maxBytes, v)
   226  		if f != nil || hasMoreRetransmissions {
   227  			if f == nil {
   228  				return nil, true
   229  			}
   230  			// We always claim that we have more data to send.
   231  			// This might be incorrect, in which case there'll be a spurious call to popStreamFrame in the future.
   232  			return f, true
   233  		}
   234  	}
   235  
   236  	if len(s.dataForWriting) == 0 && s.nextFrame == nil {
   237  		if s.finishedWriting && !s.finSent {
   238  			s.finSent = true
   239  			return &wire.StreamFrame{
   240  				StreamID:       s.streamID,
   241  				Offset:         s.writeOffset,
   242  				DataLenPresent: true,
   243  				Fin:            true,
   244  			}, false
   245  		}
   246  		return nil, false
   247  	}
   248  
   249  	sendWindow := s.flowController.SendWindowSize()
   250  	if sendWindow == 0 {
   251  		if isBlocked, offset := s.flowController.IsNewlyBlocked(); isBlocked {
   252  			s.sender.queueControlFrame(&wire.StreamDataBlockedFrame{
   253  				StreamID:          s.streamID,
   254  				MaximumStreamData: offset,
   255  			})
   256  			return nil, false
   257  		}
   258  		return nil, true
   259  	}
   260  
   261  	f, hasMoreData := s.popNewStreamFrame(maxBytes, sendWindow, v)
   262  	if dataLen := f.DataLen(); dataLen > 0 {
   263  		s.writeOffset += f.DataLen()
   264  		s.flowController.AddBytesSent(f.DataLen())
   265  	}
   266  	f.Fin = s.finishedWriting && s.dataForWriting == nil && s.nextFrame == nil && !s.finSent
   267  	if f.Fin {
   268  		s.finSent = true
   269  	}
   270  	return f, hasMoreData
   271  }
   272  
   273  func (s *sendStream) popNewStreamFrame(maxBytes, sendWindow protocol.ByteCount, v protocol.VersionNumber) (*wire.StreamFrame, bool) {
   274  	if s.nextFrame != nil {
   275  		nextFrame := s.nextFrame
   276  		s.nextFrame = nil
   277  
   278  		maxDataLen := utils.Min(sendWindow, nextFrame.MaxDataLen(maxBytes, v))
   279  		if nextFrame.DataLen() > maxDataLen {
   280  			s.nextFrame = wire.GetStreamFrame()
   281  			s.nextFrame.StreamID = s.streamID
   282  			s.nextFrame.Offset = s.writeOffset + maxDataLen
   283  			s.nextFrame.Data = s.nextFrame.Data[:nextFrame.DataLen()-maxDataLen]
   284  			s.nextFrame.DataLenPresent = true
   285  			copy(s.nextFrame.Data, nextFrame.Data[maxDataLen:])
   286  			nextFrame.Data = nextFrame.Data[:maxDataLen]
   287  		} else {
   288  			s.signalWrite()
   289  		}
   290  		return nextFrame, s.nextFrame != nil || s.dataForWriting != nil
   291  	}
   292  
   293  	f := wire.GetStreamFrame()
   294  	f.Fin = false
   295  	f.StreamID = s.streamID
   296  	f.Offset = s.writeOffset
   297  	f.DataLenPresent = true
   298  	f.Data = f.Data[:0]
   299  
   300  	hasMoreData := s.popNewStreamFrameWithoutBuffer(f, maxBytes, sendWindow, v)
   301  	if len(f.Data) == 0 && !f.Fin {
   302  		f.PutBack()
   303  		return nil, hasMoreData
   304  	}
   305  	return f, hasMoreData
   306  }
   307  
   308  func (s *sendStream) popNewStreamFrameWithoutBuffer(f *wire.StreamFrame, maxBytes, sendWindow protocol.ByteCount, v protocol.VersionNumber) bool {
   309  	maxDataLen := f.MaxDataLen(maxBytes, v)
   310  	if maxDataLen == 0 { // a STREAM frame must have at least one byte of data
   311  		return s.dataForWriting != nil || s.nextFrame != nil || s.finishedWriting
   312  	}
   313  	s.getDataForWriting(f, utils.Min(maxDataLen, sendWindow))
   314  
   315  	return s.dataForWriting != nil || s.nextFrame != nil || s.finishedWriting
   316  }
   317  
   318  func (s *sendStream) maybeGetRetransmission(maxBytes protocol.ByteCount, v protocol.VersionNumber) (*wire.StreamFrame, bool /* has more retransmissions */) {
   319  	f := s.retransmissionQueue[0]
   320  	newFrame, needsSplit := f.MaybeSplitOffFrame(maxBytes, v)
   321  	if needsSplit {
   322  		return newFrame, true
   323  	}
   324  	s.retransmissionQueue = s.retransmissionQueue[1:]
   325  	return f, len(s.retransmissionQueue) > 0
   326  }
   327  
   328  func (s *sendStream) hasData() bool {
   329  	s.mutex.Lock()
   330  	hasData := len(s.dataForWriting) > 0
   331  	s.mutex.Unlock()
   332  	return hasData
   333  }
   334  
   335  func (s *sendStream) getDataForWriting(f *wire.StreamFrame, maxBytes protocol.ByteCount) {
   336  	if protocol.ByteCount(len(s.dataForWriting)) <= maxBytes {
   337  		f.Data = f.Data[:len(s.dataForWriting)]
   338  		copy(f.Data, s.dataForWriting)
   339  		s.dataForWriting = nil
   340  		s.signalWrite()
   341  		return
   342  	}
   343  	f.Data = f.Data[:maxBytes]
   344  	copy(f.Data, s.dataForWriting)
   345  	s.dataForWriting = s.dataForWriting[maxBytes:]
   346  	if s.canBufferStreamFrame() {
   347  		s.signalWrite()
   348  	}
   349  }
   350  
   351  func (s *sendStream) isNewlyCompleted() bool {
   352  	completed := (s.finSent || s.cancelWriteErr != nil) && s.numOutstandingFrames == 0 && len(s.retransmissionQueue) == 0
   353  	if completed && !s.completed {
   354  		s.completed = true
   355  		return true
   356  	}
   357  	return false
   358  }
   359  
   360  func (s *sendStream) Close() error {
   361  	s.mutex.Lock()
   362  	if s.closeForShutdownErr != nil {
   363  		s.mutex.Unlock()
   364  		return nil
   365  	}
   366  	if s.cancelWriteErr != nil {
   367  		s.mutex.Unlock()
   368  		return fmt.Errorf("close called for canceled stream %d", s.streamID)
   369  	}
   370  	s.ctxCancel(nil)
   371  	s.finishedWriting = true
   372  	s.mutex.Unlock()
   373  
   374  	s.sender.onHasStreamData(s.streamID) // need to send the FIN, must be called without holding the mutex
   375  	return nil
   376  }
   377  
// CancelWrite aborts sending on this stream with the given error code.
// The cancellation is recorded as local (not initiated by the peer).
func (s *sendStream) CancelWrite(errorCode StreamErrorCode) {
	s.cancelWriteImpl(errorCode, false)
}
   381  
   382  // must be called after locking the mutex
   383  func (s *sendStream) cancelWriteImpl(errorCode qerr.StreamErrorCode, remote bool) {
   384  	s.mutex.Lock()
   385  	if s.cancelWriteErr != nil {
   386  		s.mutex.Unlock()
   387  		return
   388  	}
   389  	s.cancelWriteErr = &StreamError{StreamID: s.streamID, ErrorCode: errorCode, Remote: remote}
   390  	s.ctxCancel(s.cancelWriteErr)
   391  	s.numOutstandingFrames = 0
   392  	s.retransmissionQueue = nil
   393  	newlyCompleted := s.isNewlyCompleted()
   394  	s.mutex.Unlock()
   395  
   396  	s.signalWrite()
   397  	s.sender.queueControlFrame(&wire.ResetStreamFrame{
   398  		StreamID:  s.streamID,
   399  		FinalSize: s.writeOffset,
   400  		ErrorCode: errorCode,
   401  	})
   402  	if newlyCompleted {
   403  		s.sender.onStreamCompleted(s.streamID)
   404  	}
   405  }
   406  
   407  func (s *sendStream) updateSendWindow(limit protocol.ByteCount) {
   408  	s.mutex.Lock()
   409  	hasStreamData := s.dataForWriting != nil || s.nextFrame != nil
   410  	s.mutex.Unlock()
   411  
   412  	s.flowController.UpdateSendWindow(limit)
   413  	if hasStreamData {
   414  		s.sender.onHasStreamData(s.streamID)
   415  	}
   416  }
   417  
// handleStopSendingFrame cancels the write side using the error code carried
// by the STOP_SENDING frame, recording the cancellation as remote.
func (s *sendStream) handleStopSendingFrame(frame *wire.StopSendingFrame) {
	s.cancelWriteImpl(frame.ErrorCode, true)
}
   421  
// Context returns the stream's context. It is canceled when the stream is
// closed, when writing is canceled, or when the stream is closed for shutdown.
func (s *sendStream) Context() context.Context {
	return s.ctx
}
   425  
// SetWriteDeadline sets the deadline for Write calls. The zero time disables
// the deadline. A Write call that is currently blocked is woken up so that it
// picks up the new deadline. The returned error is always nil.
func (s *sendStream) SetWriteDeadline(t time.Time) error {
	s.mutex.Lock()
	s.deadline = t
	s.mutex.Unlock()
	s.signalWrite()
	return nil
}
   433  
// closeForShutdown closes a stream abruptly.
// It makes Write unblock (and return the error) immediately.
// The peer will NOT be informed about this: the stream is closed without sending a FIN or RST.
func (s *sendStream) closeForShutdown(err error) {
	s.mutex.Lock()
	s.ctxCancel(err)
	s.closeForShutdownErr = err
	s.mutex.Unlock()
	s.signalWrite()
}
   444  
// signalWrite performs a non-blocking send on the writeChan.
// If a signal is already pending in the channel's buffer, this one is dropped.
func (s *sendStream) signalWrite() {
	select {
	case s.writeChan <- struct{}{}:
	default:
	}
}
   452  
   453  type sendStreamAckHandler sendStream
   454  
   455  var _ ackhandler.FrameHandler = &sendStreamAckHandler{}
   456  
   457  func (s *sendStreamAckHandler) OnAcked(f wire.Frame) {
   458  	sf := f.(*wire.StreamFrame)
   459  	sf.PutBack()
   460  	s.mutex.Lock()
   461  	if s.cancelWriteErr != nil {
   462  		s.mutex.Unlock()
   463  		return
   464  	}
   465  	s.numOutstandingFrames--
   466  	if s.numOutstandingFrames < 0 {
   467  		panic("numOutStandingFrames negative")
   468  	}
   469  	newlyCompleted := (*sendStream)(s).isNewlyCompleted()
   470  	s.mutex.Unlock()
   471  
   472  	if newlyCompleted {
   473  		s.sender.onStreamCompleted(s.streamID)
   474  	}
   475  }
   476  
   477  func (s *sendStreamAckHandler) OnLost(f wire.Frame) {
   478  	sf := f.(*wire.StreamFrame)
   479  	s.mutex.Lock()
   480  	if s.cancelWriteErr != nil {
   481  		s.mutex.Unlock()
   482  		return
   483  	}
   484  	sf.DataLenPresent = true
   485  	s.retransmissionQueue = append(s.retransmissionQueue, sf)
   486  	s.numOutstandingFrames--
   487  	if s.numOutstandingFrames < 0 {
   488  		panic("numOutStandingFrames negative")
   489  	}
   490  	s.mutex.Unlock()
   491  
   492  	s.sender.onHasStreamData(s.streamID)
   493  }