github.com/ooni/psiphon/tunnel-core@v0.0.0-20230105123940-fe12a24c96ee/oovendor/quic-go/send_stream.go (about)

     1  package quic
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"sync"
     7  	"time"
     8  
     9  	"github.com/ooni/psiphon/tunnel-core/oovendor/quic-go/internal/ackhandler"
    10  	"github.com/ooni/psiphon/tunnel-core/oovendor/quic-go/internal/flowcontrol"
    11  	"github.com/ooni/psiphon/tunnel-core/oovendor/quic-go/internal/protocol"
    12  	"github.com/ooni/psiphon/tunnel-core/oovendor/quic-go/internal/qerr"
    13  	"github.com/ooni/psiphon/tunnel-core/oovendor/quic-go/internal/utils"
    14  	"github.com/ooni/psiphon/tunnel-core/oovendor/quic-go/internal/wire"
    15  )
    16  
// sendStreamI is the internal view of a send stream: the public SendStream API
// plus the unexported methods used internally to assemble STREAM frames and to
// react to control frames from the peer.
type sendStreamI interface {
	SendStream
	// handleStopSendingFrame processes a STOP_SENDING frame received from the peer.
	handleStopSendingFrame(*wire.StopSendingFrame)
	// hasData reports whether an in-progress Write still has unsent data.
	hasData() bool
	// popStreamFrame returns the next frame to send (at most maxBytes long,
	// including the frame header) and whether more data remains to be sent.
	popStreamFrame(maxBytes protocol.ByteCount) (*ackhandler.Frame, bool)
	// closeForShutdown closes the stream abruptly with the given error.
	closeForShutdown(error)
	// updateSendWindow passes a new send-window limit to the flow controller.
	updateSendWindow(protocol.ByteCount)
}
    25  
// sendStream implements the sending half of a QUIC stream. Its state is shared
// between application calls (Write, Close, CancelWrite, ...) and packet
// assembly (popStreamFrame), and is protected by mutex.
type sendStream struct {
	mutex sync.Mutex

	// numOutstandingFrames counts frames returned by popStreamFrame that have been
	// neither acknowledged (frameAcked) nor declared lost (queueRetransmission).
	numOutstandingFrames int64
	// retransmissionQueue holds lost STREAM frames; they are sent before new data.
	retransmissionQueue []*wire.StreamFrame

	// ctx is returned by Context(); ctxCancel is invoked by Close, by write
	// cancellation (CancelWrite / STOP_SENDING), and by closeForShutdown.
	ctx       context.Context
	ctxCancel context.CancelFunc

	streamID protocol.StreamID
	sender   streamSender

	// writeOffset is the stream offset at which the next new data will be sent.
	writeOffset protocol.ByteCount

	cancelWriteErr      error // returned by Write once the write side has been canceled
	closeForShutdownErr error // returned by Write once closeForShutdown has been called

	closedForShutdown bool // set when closeForShutdown() is called
	finishedWriting   bool // set once Close() is called
	canceledWrite     bool // set when CancelWrite() is called, or a STOP_SENDING frame is received
	finSent           bool // set when a STREAM_FRAME with FIN bit has been sent
	completed         bool // set when this stream has been reported to the streamSender as completed

	dataForWriting []byte // during a Write() call, this slice is the part of p that still needs to be sent out
	// nextFrame buffers data copied out of a Write call so that Write can return
	// before the data has been sent; it is popped before dataForWriting.
	nextFrame *wire.StreamFrame

	// writeChan wakes a Write that is blocked waiting for progress (see signalWrite).
	writeChan chan struct{}
	// deadline is the write deadline; the zero value means no deadline.
	deadline time.Time

	flowController flowcontrol.StreamFlowController

	version protocol.VersionNumber
}
    59  
    60  var (
    61  	_ SendStream  = &sendStream{}
    62  	_ sendStreamI = &sendStream{}
    63  )
    64  
    65  func newSendStream(
    66  	streamID protocol.StreamID,
    67  	sender streamSender,
    68  	flowController flowcontrol.StreamFlowController,
    69  	version protocol.VersionNumber,
    70  ) *sendStream {
    71  	s := &sendStream{
    72  		streamID:       streamID,
    73  		sender:         sender,
    74  		flowController: flowController,
    75  		writeChan:      make(chan struct{}, 1),
    76  		version:        version,
    77  	}
    78  	s.ctx, s.ctxCancel = context.WithCancel(context.Background())
    79  	return s
    80  }
    81  
    82  func (s *sendStream) StreamID() protocol.StreamID {
    83  	return s.streamID // same for receiveStream and sendStream
    84  }
    85  
// Write hands p to the stream for sending. It blocks until all of p has either
// been consumed by popStreamFrame or copied into the buffered nextFrame, until
// the write deadline expires, or until the write side is canceled or shut down.
// It returns the number of bytes of p that were accepted.
//
// Errors: writing after Close returns an error; after CancelWrite or a
// STOP_SENDING frame it returns cancelWriteErr; after closeForShutdown it
// returns closeForShutdownErr; an expired deadline yields errDeadline.
func (s *sendStream) Write(p []byte) (int, error) {
	s.mutex.Lock()
	defer s.mutex.Unlock()

	if s.finishedWriting {
		return 0, fmt.Errorf("write on closed stream %d", s.streamID)
	}
	if s.canceledWrite {
		return 0, s.cancelWriteErr
	}
	if s.closeForShutdownErr != nil {
		return 0, s.closeForShutdownErr
	}
	if !s.deadline.IsZero() && !time.Now().Before(s.deadline) {
		return 0, errDeadline
	}
	if len(p) == 0 {
		return 0, nil
	}

	// dataForWriting aliases the caller's p. popStreamFrame reads from it directly
	// while this goroutine waits below with the mutex released.
	s.dataForWriting = p

	var (
		deadlineTimer  *utils.Timer
		bytesWritten   int
		notifiedSender bool
	)
	for {
		var copied bool
		var deadline time.Time
		// As soon as dataForWriting becomes smaller than a certain size x, we copy all the data to a STREAM frame (s.nextFrame),
		// which can then be popped the next time we assemble a packet.
		// This allows us to return Write() when all data but x bytes have been sent out.
		// When the user now calls Close(), this is much more likely to happen before we popped that last STREAM frame,
		// allowing us to set the FIN bit on that frame (instead of sending an empty STREAM frame with FIN).
		// NOTE(review): this branch is evaluated before the canceledWrite / closedForShutdown / deadline
		// checks in the else branch, so after a wake-up caused by cancellation (or the deadline timer) a
		// remainder that now fits the buffer is copied and Write reports success. Confirm this is intended.
		if s.canBufferStreamFrame() && len(s.dataForWriting) > 0 {
			if s.nextFrame == nil {
				f := wire.GetStreamFrame()
				f.Offset = s.writeOffset
				f.StreamID = s.streamID
				f.DataLenPresent = true
				f.Data = f.Data[:len(s.dataForWriting)]
				copy(f.Data, s.dataForWriting)
				s.nextFrame = f
			} else {
				// Append to the already-buffered frame. This reslices within the existing
				// capacity; presumably frames from wire.GetStreamFrame have room for
				// MaxPacketBufferSize bytes (which canBufferStreamFrame enforces) — TODO confirm.
				l := len(s.nextFrame.Data)
				s.nextFrame.Data = s.nextFrame.Data[:l+len(s.dataForWriting)]
				copy(s.nextFrame.Data[l:], s.dataForWriting)
			}
			s.dataForWriting = nil
			bytesWritten = len(p)
			copied = true
		} else {
			bytesWritten = len(p) - len(s.dataForWriting)
			deadline = s.deadline
			if !deadline.IsZero() {
				if !time.Now().Before(deadline) {
					// Stop popStreamFrame from reading p after we return.
					s.dataForWriting = nil
					return bytesWritten, errDeadline
				}
				if deadlineTimer == nil {
					deadlineTimer = utils.NewTimer()
					defer deadlineTimer.Stop()
				}
				deadlineTimer.Reset(deadline)
			}
			if s.dataForWriting == nil || s.canceledWrite || s.closedForShutdown {
				break
			}
		}

		s.mutex.Unlock()
		if !notifiedSender {
			s.sender.onHasStreamData(s.streamID) // must be called without holding the mutex
			notifiedSender = true
		}
		if copied {
			// Re-lock so the deferred Unlock stays balanced.
			s.mutex.Lock()
			break
		}
		// Wait for popStreamFrame (or cancellation / shutdown / SetWriteDeadline)
		// to signal progress via writeChan, or for the deadline to fire.
		if deadline.IsZero() {
			<-s.writeChan
		} else {
			select {
			case <-s.writeChan:
			case <-deadlineTimer.Chan():
				deadlineTimer.SetRead()
			}
		}
		s.mutex.Lock()
	}

	if bytesWritten == len(p) {
		return bytesWritten, nil
	}
	if s.closeForShutdownErr != nil {
		return bytesWritten, s.closeForShutdownErr
	} else if s.cancelWriteErr != nil {
		return bytesWritten, s.cancelWriteErr
	}
	return bytesWritten, nil
}
   188  
   189  func (s *sendStream) canBufferStreamFrame() bool {
   190  	var l protocol.ByteCount
   191  	if s.nextFrame != nil {
   192  		l = s.nextFrame.DataLen()
   193  	}
   194  	return l+protocol.ByteCount(len(s.dataForWriting)) <= protocol.MaxPacketBufferSize
   195  }
   196  
   197  // popStreamFrame returns the next STREAM frame that is supposed to be sent on this stream
   198  // maxBytes is the maximum length this frame (including frame header) will have.
   199  func (s *sendStream) popStreamFrame(maxBytes protocol.ByteCount) (*ackhandler.Frame, bool /* has more data to send */) {
   200  	s.mutex.Lock()
   201  	f, hasMoreData := s.popNewOrRetransmittedStreamFrame(maxBytes)
   202  	if f != nil {
   203  		s.numOutstandingFrames++
   204  	}
   205  	s.mutex.Unlock()
   206  
   207  	if f == nil {
   208  		return nil, hasMoreData
   209  	}
   210  	return &ackhandler.Frame{Frame: f, OnLost: s.queueRetransmission, OnAcked: s.frameAcked}, hasMoreData
   211  }
   212  
   213  func (s *sendStream) popNewOrRetransmittedStreamFrame(maxBytes protocol.ByteCount) (*wire.StreamFrame, bool /* has more data to send */) {
   214  	if s.canceledWrite || s.closeForShutdownErr != nil {
   215  		return nil, false
   216  	}
   217  
   218  	if len(s.retransmissionQueue) > 0 {
   219  		f, hasMoreRetransmissions := s.maybeGetRetransmission(maxBytes)
   220  		if f != nil || hasMoreRetransmissions {
   221  			if f == nil {
   222  				return nil, true
   223  			}
   224  			// We always claim that we have more data to send.
   225  			// This might be incorrect, in which case there'll be a spurious call to popStreamFrame in the future.
   226  			return f, true
   227  		}
   228  	}
   229  
   230  	if len(s.dataForWriting) == 0 && s.nextFrame == nil {
   231  		if s.finishedWriting && !s.finSent {
   232  			s.finSent = true
   233  			return &wire.StreamFrame{
   234  				StreamID:       s.streamID,
   235  				Offset:         s.writeOffset,
   236  				DataLenPresent: true,
   237  				Fin:            true,
   238  			}, false
   239  		}
   240  		return nil, false
   241  	}
   242  
   243  	sendWindow := s.flowController.SendWindowSize()
   244  	if sendWindow == 0 {
   245  		if isBlocked, offset := s.flowController.IsNewlyBlocked(); isBlocked {
   246  			s.sender.queueControlFrame(&wire.StreamDataBlockedFrame{
   247  				StreamID:          s.streamID,
   248  				MaximumStreamData: offset,
   249  			})
   250  			return nil, false
   251  		}
   252  		return nil, true
   253  	}
   254  
   255  	f, hasMoreData := s.popNewStreamFrame(maxBytes, sendWindow)
   256  	if dataLen := f.DataLen(); dataLen > 0 {
   257  		s.writeOffset += f.DataLen()
   258  		s.flowController.AddBytesSent(f.DataLen())
   259  	}
   260  	f.Fin = s.finishedWriting && s.dataForWriting == nil && s.nextFrame == nil && !s.finSent
   261  	if f.Fin {
   262  		s.finSent = true
   263  	}
   264  	return f, hasMoreData
   265  }
   266  
   267  func (s *sendStream) popNewStreamFrame(maxBytes, sendWindow protocol.ByteCount) (*wire.StreamFrame, bool) {
   268  	if s.nextFrame != nil {
   269  		nextFrame := s.nextFrame
   270  		s.nextFrame = nil
   271  
   272  		maxDataLen := utils.MinByteCount(sendWindow, nextFrame.MaxDataLen(maxBytes, s.version))
   273  		if nextFrame.DataLen() > maxDataLen {
   274  			s.nextFrame = wire.GetStreamFrame()
   275  			s.nextFrame.StreamID = s.streamID
   276  			s.nextFrame.Offset = s.writeOffset + maxDataLen
   277  			s.nextFrame.Data = s.nextFrame.Data[:nextFrame.DataLen()-maxDataLen]
   278  			s.nextFrame.DataLenPresent = true
   279  			copy(s.nextFrame.Data, nextFrame.Data[maxDataLen:])
   280  			nextFrame.Data = nextFrame.Data[:maxDataLen]
   281  		} else {
   282  			s.signalWrite()
   283  		}
   284  		return nextFrame, s.nextFrame != nil || s.dataForWriting != nil
   285  	}
   286  
   287  	f := wire.GetStreamFrame()
   288  	f.Fin = false
   289  	f.StreamID = s.streamID
   290  	f.Offset = s.writeOffset
   291  	f.DataLenPresent = true
   292  	f.Data = f.Data[:0]
   293  
   294  	hasMoreData := s.popNewStreamFrameWithoutBuffer(f, maxBytes, sendWindow)
   295  	if len(f.Data) == 0 && !f.Fin {
   296  		f.PutBack()
   297  		return nil, hasMoreData
   298  	}
   299  	return f, hasMoreData
   300  }
   301  
   302  func (s *sendStream) popNewStreamFrameWithoutBuffer(f *wire.StreamFrame, maxBytes, sendWindow protocol.ByteCount) bool {
   303  	maxDataLen := f.MaxDataLen(maxBytes, s.version)
   304  	if maxDataLen == 0 { // a STREAM frame must have at least one byte of data
   305  		return s.dataForWriting != nil || s.nextFrame != nil || s.finishedWriting
   306  	}
   307  	s.getDataForWriting(f, utils.MinByteCount(maxDataLen, sendWindow))
   308  
   309  	return s.dataForWriting != nil || s.nextFrame != nil || s.finishedWriting
   310  }
   311  
   312  func (s *sendStream) maybeGetRetransmission(maxBytes protocol.ByteCount) (*wire.StreamFrame, bool /* has more retransmissions */) {
   313  	f := s.retransmissionQueue[0]
   314  	newFrame, needsSplit := f.MaybeSplitOffFrame(maxBytes, s.version)
   315  	if needsSplit {
   316  		return newFrame, true
   317  	}
   318  	s.retransmissionQueue = s.retransmissionQueue[1:]
   319  	return f, len(s.retransmissionQueue) > 0
   320  }
   321  
   322  func (s *sendStream) hasData() bool {
   323  	s.mutex.Lock()
   324  	hasData := len(s.dataForWriting) > 0
   325  	s.mutex.Unlock()
   326  	return hasData
   327  }
   328  
   329  func (s *sendStream) getDataForWriting(f *wire.StreamFrame, maxBytes protocol.ByteCount) {
   330  	if protocol.ByteCount(len(s.dataForWriting)) <= maxBytes {
   331  		f.Data = f.Data[:len(s.dataForWriting)]
   332  		copy(f.Data, s.dataForWriting)
   333  		s.dataForWriting = nil
   334  		s.signalWrite()
   335  		return
   336  	}
   337  	f.Data = f.Data[:maxBytes]
   338  	copy(f.Data, s.dataForWriting)
   339  	s.dataForWriting = s.dataForWriting[maxBytes:]
   340  	if s.canBufferStreamFrame() {
   341  		s.signalWrite()
   342  	}
   343  }
   344  
   345  func (s *sendStream) frameAcked(f wire.Frame) {
   346  	f.(*wire.StreamFrame).PutBack()
   347  
   348  	s.mutex.Lock()
   349  	if s.canceledWrite {
   350  		s.mutex.Unlock()
   351  		return
   352  	}
   353  	s.numOutstandingFrames--
   354  	if s.numOutstandingFrames < 0 {
   355  		panic("numOutStandingFrames negative")
   356  	}
   357  	newlyCompleted := s.isNewlyCompleted()
   358  	s.mutex.Unlock()
   359  
   360  	if newlyCompleted {
   361  		s.sender.onStreamCompleted(s.streamID)
   362  	}
   363  }
   364  
   365  func (s *sendStream) isNewlyCompleted() bool {
   366  	completed := (s.finSent || s.canceledWrite) && s.numOutstandingFrames == 0 && len(s.retransmissionQueue) == 0
   367  	if completed && !s.completed {
   368  		s.completed = true
   369  		return true
   370  	}
   371  	return false
   372  }
   373  
   374  func (s *sendStream) queueRetransmission(f wire.Frame) {
   375  	sf := f.(*wire.StreamFrame)
   376  	sf.DataLenPresent = true
   377  	s.mutex.Lock()
   378  	if s.canceledWrite {
   379  		s.mutex.Unlock()
   380  		return
   381  	}
   382  	s.retransmissionQueue = append(s.retransmissionQueue, sf)
   383  	s.numOutstandingFrames--
   384  	if s.numOutstandingFrames < 0 {
   385  		panic("numOutStandingFrames negative")
   386  	}
   387  	s.mutex.Unlock()
   388  
   389  	s.sender.onHasStreamData(s.streamID)
   390  }
   391  
   392  func (s *sendStream) Close() error {
   393  	s.mutex.Lock()
   394  	if s.closedForShutdown {
   395  		s.mutex.Unlock()
   396  		return nil
   397  	}
   398  	if s.canceledWrite {
   399  		s.mutex.Unlock()
   400  		return fmt.Errorf("close called for canceled stream %d", s.streamID)
   401  	}
   402  	s.ctxCancel()
   403  	s.finishedWriting = true
   404  	s.mutex.Unlock()
   405  
   406  	s.sender.onHasStreamData(s.streamID) // need to send the FIN, must be called without holding the mutex
   407  	return nil
   408  }
   409  
   410  func (s *sendStream) CancelWrite(errorCode StreamErrorCode) {
   411  	s.cancelWriteImpl(errorCode, fmt.Errorf("Write on stream %d canceled with error code %d", s.streamID, errorCode))
   412  }
   413  
   414  // must be called after locking the mutex
   415  func (s *sendStream) cancelWriteImpl(errorCode qerr.StreamErrorCode, writeErr error) {
   416  	s.mutex.Lock()
   417  	if s.canceledWrite {
   418  		s.mutex.Unlock()
   419  		return
   420  	}
   421  	s.ctxCancel()
   422  	s.canceledWrite = true
   423  	s.cancelWriteErr = writeErr
   424  	s.numOutstandingFrames = 0
   425  	s.retransmissionQueue = nil
   426  	newlyCompleted := s.isNewlyCompleted()
   427  	s.mutex.Unlock()
   428  
   429  	s.signalWrite()
   430  	s.sender.queueControlFrame(&wire.ResetStreamFrame{
   431  		StreamID:  s.streamID,
   432  		FinalSize: s.writeOffset,
   433  		ErrorCode: errorCode,
   434  	})
   435  	if newlyCompleted {
   436  		s.sender.onStreamCompleted(s.streamID)
   437  	}
   438  }
   439  
   440  func (s *sendStream) updateSendWindow(limit protocol.ByteCount) {
   441  	s.mutex.Lock()
   442  	hasStreamData := s.dataForWriting != nil || s.nextFrame != nil
   443  	s.mutex.Unlock()
   444  
   445  	s.flowController.UpdateSendWindow(limit)
   446  	if hasStreamData {
   447  		s.sender.onHasStreamData(s.streamID)
   448  	}
   449  }
   450  
   451  func (s *sendStream) handleStopSendingFrame(frame *wire.StopSendingFrame) {
   452  	s.cancelWriteImpl(frame.ErrorCode, &StreamError{
   453  		StreamID:  s.streamID,
   454  		ErrorCode: frame.ErrorCode,
   455  	})
   456  }
   457  
   458  func (s *sendStream) Context() context.Context {
   459  	return s.ctx
   460  }
   461  
   462  func (s *sendStream) SetWriteDeadline(t time.Time) error {
   463  	s.mutex.Lock()
   464  	s.deadline = t
   465  	s.mutex.Unlock()
   466  	s.signalWrite()
   467  	return nil
   468  }
   469  
   470  // CloseForShutdown closes a stream abruptly.
   471  // It makes Write unblock (and return the error) immediately.
   472  // The peer will NOT be informed about this: the stream is closed without sending a FIN or RST.
   473  func (s *sendStream) closeForShutdown(err error) {
   474  	s.mutex.Lock()
   475  	s.ctxCancel()
   476  	s.closedForShutdown = true
   477  	s.closeForShutdownErr = err
   478  	s.mutex.Unlock()
   479  	s.signalWrite()
   480  }
   481  
   482  // signalWrite performs a non-blocking send on the writeChan
   483  func (s *sendStream) signalWrite() {
   484  	select {
   485  	case s.writeChan <- struct{}{}:
   486  	default:
   487  	}
   488  }