github.com/psiphon-labs/psiphon-tunnel-core@v2.0.28+incompatible/psiphon/common/quic/gquic-go/send_stream.go (about)

     1  package gquic
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"sync"
     7  	"time"
     8  
     9  	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/quic/gquic-go/internal/flowcontrol"
    10  	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/quic/gquic-go/internal/protocol"
    11  	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/quic/gquic-go/internal/utils"
    12  	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/quic/gquic-go/internal/wire"
    13  )
    14  
    15  type sendStreamI interface {
    16  	SendStream
    17  	handleStopSendingFrame(*wire.StopSendingFrame)
    18  	hasData() bool
    19  	popStreamFrame(maxBytes protocol.ByteCount) (*wire.StreamFrame, bool)
    20  	closeForShutdown(error)
    21  	handleMaxStreamDataFrame(*wire.MaxStreamDataFrame)
    22  }
    23  
    24  type sendStream struct {
    25  	mutex sync.Mutex
    26  
    27  	ctx       context.Context
    28  	ctxCancel context.CancelFunc
    29  
    30  	streamID protocol.StreamID
    31  	sender   streamSender
    32  
    33  	writeOffset protocol.ByteCount
    34  
    35  	cancelWriteErr      error
    36  	closeForShutdownErr error
    37  
    38  	closedForShutdown bool // set when CloseForShutdown() is called
    39  	finishedWriting   bool // set once Close() is called
    40  	canceledWrite     bool // set when CancelWrite() is called, or a STOP_SENDING frame is received
    41  	finSent           bool // set when a STREAM_FRAME with FIN bit has b
    42  
    43  	dataForWriting []byte
    44  
    45  	writeChan chan struct{}
    46  	deadline  time.Time
    47  
    48  	flowController flowcontrol.StreamFlowController
    49  
    50  	version protocol.VersionNumber
    51  }
    52  
    53  var _ SendStream = &sendStream{}
    54  var _ sendStreamI = &sendStream{}
    55  
    56  func newSendStream(
    57  	streamID protocol.StreamID,
    58  	sender streamSender,
    59  	flowController flowcontrol.StreamFlowController,
    60  	version protocol.VersionNumber,
    61  ) *sendStream {
    62  	s := &sendStream{
    63  		streamID:       streamID,
    64  		sender:         sender,
    65  		flowController: flowController,
    66  		writeChan:      make(chan struct{}, 1),
    67  		version:        version,
    68  	}
    69  	s.ctx, s.ctxCancel = context.WithCancel(context.Background())
    70  	return s
    71  }
    72  
    73  func (s *sendStream) StreamID() protocol.StreamID {
    74  	return s.streamID // same for receiveStream and sendStream
    75  }
    76  
    77  func (s *sendStream) Write(p []byte) (int, error) {
    78  	s.mutex.Lock()
    79  	defer s.mutex.Unlock()
    80  
    81  	if s.finishedWriting {
    82  		return 0, fmt.Errorf("write on closed stream %d", s.streamID)
    83  	}
    84  	if s.canceledWrite {
    85  		return 0, s.cancelWriteErr
    86  	}
    87  	if s.closeForShutdownErr != nil {
    88  		return 0, s.closeForShutdownErr
    89  	}
    90  	if !s.deadline.IsZero() && !time.Now().Before(s.deadline) {
    91  		return 0, errDeadline
    92  	}
    93  	if len(p) == 0 {
    94  		return 0, nil
    95  	}
    96  
    97  	s.dataForWriting = p
    98  
    99  	var (
   100  		deadlineTimer  *utils.Timer
   101  		bytesWritten   int
   102  		notifiedSender bool
   103  	)
   104  	for {
   105  		bytesWritten = len(p) - len(s.dataForWriting)
   106  		deadline := s.deadline
   107  		if !deadline.IsZero() {
   108  			if !time.Now().Before(deadline) {
   109  				s.dataForWriting = nil
   110  				return bytesWritten, errDeadline
   111  			}
   112  			if deadlineTimer == nil {
   113  				deadlineTimer = utils.NewTimer()
   114  			}
   115  			deadlineTimer.Reset(deadline)
   116  		}
   117  		if s.dataForWriting == nil || s.canceledWrite || s.closedForShutdown {
   118  			break
   119  		}
   120  
   121  		s.mutex.Unlock()
   122  		if !notifiedSender {
   123  			s.sender.onHasStreamData(s.streamID) // must be called without holding the mutex
   124  			notifiedSender = true
   125  		}
   126  		if deadline.IsZero() {
   127  			<-s.writeChan
   128  		} else {
   129  			select {
   130  			case <-s.writeChan:
   131  			case <-deadlineTimer.Chan():
   132  				deadlineTimer.SetRead()
   133  			}
   134  		}
   135  		s.mutex.Lock()
   136  	}
   137  
   138  	// [Psiphon]
   139  	// Stop timer to immediately release resources
   140  	if deadlineTimer != nil {
   141  		deadlineTimer.Reset(time.Time{})
   142  	}
   143  
   144  	if s.closeForShutdownErr != nil {
   145  		return bytesWritten, s.closeForShutdownErr
   146  	} else if s.cancelWriteErr != nil {
   147  		return bytesWritten, s.cancelWriteErr
   148  	}
   149  	return bytesWritten, nil
   150  }
   151  
   152  // popStreamFrame returns the next STREAM frame that is supposed to be sent on this stream
   153  // maxBytes is the maximum length this frame (including frame header) will have.
   154  func (s *sendStream) popStreamFrame(maxBytes protocol.ByteCount) (*wire.StreamFrame, bool /* has more data to send */) {
   155  	completed, frame, hasMoreData := s.popStreamFrameImpl(maxBytes)
   156  	if completed {
   157  		s.sender.onStreamCompleted(s.streamID)
   158  	}
   159  	return frame, hasMoreData
   160  }
   161  
   162  func (s *sendStream) popStreamFrameImpl(maxBytes protocol.ByteCount) (bool /* completed */, *wire.StreamFrame, bool /* has more data to send */) {
   163  	s.mutex.Lock()
   164  	defer s.mutex.Unlock()
   165  
   166  	if s.closeForShutdownErr != nil {
   167  		return false, nil, false
   168  	}
   169  
   170  	frame := &wire.StreamFrame{
   171  		StreamID:       s.streamID,
   172  		Offset:         s.writeOffset,
   173  		DataLenPresent: true,
   174  	}
   175  	maxDataLen := frame.MaxDataLen(maxBytes, s.version)
   176  	if maxDataLen == 0 { // a STREAM frame must have at least one byte of data
   177  		return false, nil, s.dataForWriting != nil
   178  	}
   179  	frame.Data, frame.FinBit = s.getDataForWriting(maxDataLen)
   180  	if len(frame.Data) == 0 && !frame.FinBit {
   181  		// this can happen if:
   182  		// - popStreamFrame is called but there's no data for writing
   183  		// - there's data for writing, but the stream is stream-level flow control blocked
   184  		// - there's data for writing, but the stream is connection-level flow control blocked
   185  		if s.dataForWriting == nil {
   186  			return false, nil, false
   187  		}
   188  		if isBlocked, offset := s.flowController.IsNewlyBlocked(); isBlocked {
   189  			s.sender.queueControlFrame(&wire.StreamBlockedFrame{
   190  				StreamID: s.streamID,
   191  				Offset:   offset,
   192  			})
   193  			return false, nil, false
   194  		}
   195  		return false, nil, true
   196  	}
   197  	if frame.FinBit {
   198  		s.finSent = true
   199  	}
   200  	return frame.FinBit, frame, s.dataForWriting != nil
   201  }
   202  
   203  func (s *sendStream) hasData() bool {
   204  	s.mutex.Lock()
   205  	hasData := len(s.dataForWriting) > 0
   206  	s.mutex.Unlock()
   207  	return hasData
   208  }
   209  
   210  func (s *sendStream) getDataForWriting(maxBytes protocol.ByteCount) ([]byte, bool /* should send FIN */) {
   211  	if s.dataForWriting == nil {
   212  		return nil, s.finishedWriting && !s.finSent
   213  	}
   214  
   215  	if s.streamID != s.version.CryptoStreamID() {
   216  		maxBytes = utils.MinByteCount(maxBytes, s.flowController.SendWindowSize())
   217  	}
   218  	if maxBytes == 0 {
   219  		return nil, false
   220  	}
   221  
   222  	var ret []byte
   223  	if protocol.ByteCount(len(s.dataForWriting)) > maxBytes {
   224  		ret = make([]byte, int(maxBytes))
   225  		copy(ret, s.dataForWriting[:maxBytes])
   226  		s.dataForWriting = s.dataForWriting[maxBytes:]
   227  	} else {
   228  		ret = make([]byte, len(s.dataForWriting))
   229  		copy(ret, s.dataForWriting)
   230  		s.dataForWriting = nil
   231  		s.signalWrite()
   232  	}
   233  	s.writeOffset += protocol.ByteCount(len(ret))
   234  	s.flowController.AddBytesSent(protocol.ByteCount(len(ret)))
   235  	return ret, s.finishedWriting && s.dataForWriting == nil && !s.finSent
   236  }
   237  
   238  func (s *sendStream) Close() error {
   239  	s.mutex.Lock()
   240  	if s.canceledWrite {
   241  		s.mutex.Unlock()
   242  		return fmt.Errorf("Close called for canceled stream %d", s.streamID)
   243  	}
   244  	s.finishedWriting = true
   245  	s.mutex.Unlock()
   246  
   247  	s.sender.onHasStreamData(s.streamID) // need to send the FIN, must be called without holding the mutex
   248  	s.ctxCancel()
   249  	return nil
   250  }
   251  
   252  func (s *sendStream) CancelWrite(errorCode protocol.ApplicationErrorCode) error {
   253  	s.mutex.Lock()
   254  	completed, err := s.cancelWriteImpl(errorCode, fmt.Errorf("Write on stream %d canceled with error code %d", s.streamID, errorCode))
   255  	s.mutex.Unlock()
   256  
   257  	if completed {
   258  		s.sender.onStreamCompleted(s.streamID) // must be called without holding the mutex
   259  	}
   260  	return err
   261  }
   262  
   263  // must be called after locking the mutex
   264  func (s *sendStream) cancelWriteImpl(errorCode protocol.ApplicationErrorCode, writeErr error) (bool /*completed */, error) {
   265  	if s.canceledWrite {
   266  		return false, nil
   267  	}
   268  	if s.finishedWriting {
   269  		return false, fmt.Errorf("CancelWrite for closed stream %d", s.streamID)
   270  	}
   271  	s.canceledWrite = true
   272  	s.cancelWriteErr = writeErr
   273  	s.signalWrite()
   274  	s.sender.queueControlFrame(&wire.RstStreamFrame{
   275  		StreamID:   s.streamID,
   276  		ByteOffset: s.writeOffset,
   277  		ErrorCode:  errorCode,
   278  	})
   279  	// TODO(#991): cancel retransmissions for this stream
   280  	s.ctxCancel()
   281  	return true, nil
   282  }
   283  
   284  func (s *sendStream) handleStopSendingFrame(frame *wire.StopSendingFrame) {
   285  	if completed := s.handleStopSendingFrameImpl(frame); completed {
   286  		s.sender.onStreamCompleted(s.streamID)
   287  	}
   288  }
   289  
   290  func (s *sendStream) handleMaxStreamDataFrame(frame *wire.MaxStreamDataFrame) {
   291  	s.mutex.Lock()
   292  	hasStreamData := s.dataForWriting != nil
   293  	s.mutex.Unlock()
   294  	s.flowController.UpdateSendWindow(frame.ByteOffset)
   295  	if hasStreamData {
   296  		s.sender.onHasStreamData(s.streamID)
   297  	}
   298  }
   299  
   300  // must be called after locking the mutex
   301  func (s *sendStream) handleStopSendingFrameImpl(frame *wire.StopSendingFrame) bool /*completed*/ {
   302  	s.mutex.Lock()
   303  	defer s.mutex.Unlock()
   304  
   305  	writeErr := streamCanceledError{
   306  		errorCode: frame.ErrorCode,
   307  		error:     fmt.Errorf("Stream %d was reset with error code %d", s.streamID, frame.ErrorCode),
   308  	}
   309  	errorCode := errorCodeStopping
   310  	if !s.version.UsesIETFFrameFormat() {
   311  		errorCode = errorCodeStoppingGQUIC
   312  	}
   313  	completed, _ := s.cancelWriteImpl(errorCode, writeErr)
   314  	return completed
   315  }
   316  
   317  func (s *sendStream) Context() context.Context {
   318  	return s.ctx
   319  }
   320  
   321  func (s *sendStream) SetWriteDeadline(t time.Time) error {
   322  	s.mutex.Lock()
   323  	s.deadline = t
   324  	s.mutex.Unlock()
   325  	s.signalWrite()
   326  	return nil
   327  }
   328  
   329  // CloseForShutdown closes a stream abruptly.
   330  // It makes Write unblock (and return the error) immediately.
   331  // The peer will NOT be informed about this: the stream is closed without sending a FIN or RST.
   332  func (s *sendStream) closeForShutdown(err error) {
   333  	s.mutex.Lock()
   334  	s.closedForShutdown = true
   335  	s.closeForShutdownErr = err
   336  	s.mutex.Unlock()
   337  	s.signalWrite()
   338  	s.ctxCancel()
   339  }
   340  
   341  func (s *sendStream) getWriteOffset() protocol.ByteCount {
   342  	return s.writeOffset
   343  }
   344  
   345  // signalWrite performs a non-blocking send on the writeChan
   346  func (s *sendStream) signalWrite() {
   347  	select {
   348  	case s.writeChan <- struct{}{}:
   349  	default:
   350  	}
   351  }