github.com/daeuniverse/quic-go@v0.0.0-20240413031024-943f218e0810/send_stream.go (about) 1 package quic 2 3 import ( 4 "context" 5 "fmt" 6 "sync" 7 "time" 8 9 "github.com/daeuniverse/quic-go/internal/ackhandler" 10 "github.com/daeuniverse/quic-go/internal/flowcontrol" 11 "github.com/daeuniverse/quic-go/internal/protocol" 12 "github.com/daeuniverse/quic-go/internal/qerr" 13 "github.com/daeuniverse/quic-go/internal/utils" 14 "github.com/daeuniverse/quic-go/internal/wire" 15 ) 16 17 type sendStreamI interface { 18 SendStream 19 handleStopSendingFrame(*wire.StopSendingFrame) 20 hasData() bool 21 popStreamFrame(maxBytes protocol.ByteCount, v protocol.Version) (frame ackhandler.StreamFrame, ok, hasMore bool) 22 closeForShutdown(error) 23 updateSendWindow(protocol.ByteCount) 24 } 25 26 type sendStream struct { 27 mutex sync.Mutex 28 29 numOutstandingFrames int64 30 retransmissionQueue []*wire.StreamFrame 31 32 ctx context.Context 33 ctxCancel context.CancelCauseFunc 34 35 streamID protocol.StreamID 36 sender streamSender 37 38 writeOffset protocol.ByteCount 39 40 cancelWriteErr error 41 closeForShutdownErr error 42 43 finishedWriting bool // set once Close() is called 44 finSent bool // set when a STREAM_FRAME with FIN bit has been sent 45 completed bool // set when this stream has been reported to the streamSender as completed 46 47 dataForWriting []byte // during a Write() call, this slice is the part of p that still needs to be sent out 48 nextFrame *wire.StreamFrame 49 50 writeChan chan struct{} 51 writeOnce chan struct{} 52 deadline time.Time 53 54 flowController flowcontrol.StreamFlowController 55 } 56 57 var ( 58 _ SendStream = &sendStream{} 59 _ sendStreamI = &sendStream{} 60 ) 61 62 func newSendStream( 63 streamID protocol.StreamID, 64 sender streamSender, 65 flowController flowcontrol.StreamFlowController, 66 ) *sendStream { 67 s := &sendStream{ 68 streamID: streamID, 69 sender: sender, 70 flowController: flowController, 71 writeChan: make(chan struct{}, 1), 72 writeOnce: make(chan struct{}, 1), // cap: 1, to protect against concurrent use of Write 73 } 74 s.ctx, s.ctxCancel = context.WithCancelCause(context.Background()) 75 return s 76 } 77 78 func (s *sendStream) StreamID() protocol.StreamID { 79 return s.streamID // same for receiveStream and sendStream 80 } 81 82 func (s *sendStream) Write(p []byte) (int, error) { 83 // Concurrent use of Write is not permitted (and doesn't make any sense), 84 // but sometimes people do it anyway. 85 // Make sure that we only execute one call at any given time to avoid hard to debug failures. 86 s.writeOnce <- struct{}{} 87 defer func() { <-s.writeOnce }() 88 89 s.mutex.Lock() 90 defer s.mutex.Unlock() 91 92 if s.finishedWriting { 93 return 0, fmt.Errorf("write on closed stream %d", s.streamID) 94 } 95 if s.cancelWriteErr != nil { 96 return 0, s.cancelWriteErr 97 } 98 if s.closeForShutdownErr != nil { 99 return 0, s.closeForShutdownErr 100 } 101 if !s.deadline.IsZero() && !time.Now().Before(s.deadline) { 102 return 0, errDeadline 103 } 104 if len(p) == 0 { 105 return 0, nil 106 } 107 108 s.dataForWriting = p 109 110 var ( 111 deadlineTimer *utils.Timer 112 bytesWritten int 113 notifiedSender bool 114 ) 115 for { 116 var copied bool 117 var deadline time.Time 118 // As soon as dataForWriting becomes smaller than a certain size x, we copy all the data to a STREAM frame (s.nextFrame), 119 // which can then be popped the next time we assemble a packet. 120 // This allows us to return Write() when all data but x bytes have been sent out. 121 // When the user now calls Close(), this is much more likely to happen before we popped that last STREAM frame, 122 // allowing us to set the FIN bit on that frame (instead of sending an empty STREAM frame with FIN). 123 if s.canBufferStreamFrame() && len(s.dataForWriting) > 0 { 124 if s.nextFrame == nil { 125 f := wire.GetStreamFrame() 126 f.Offset = s.writeOffset 127 f.StreamID = s.streamID 128 f.DataLenPresent = true 129 f.Data = f.Data[:len(s.dataForWriting)] 130 copy(f.Data, s.dataForWriting) 131 s.nextFrame = f 132 } else { 133 l := len(s.nextFrame.Data) 134 s.nextFrame.Data = s.nextFrame.Data[:l+len(s.dataForWriting)] 135 copy(s.nextFrame.Data[l:], s.dataForWriting) 136 } 137 s.dataForWriting = nil 138 bytesWritten = len(p) 139 copied = true 140 } else { 141 bytesWritten = len(p) - len(s.dataForWriting) 142 deadline = s.deadline 143 if !deadline.IsZero() { 144 if !time.Now().Before(deadline) { 145 s.dataForWriting = nil 146 return bytesWritten, errDeadline 147 } 148 if deadlineTimer == nil { 149 deadlineTimer = utils.NewTimer() 150 defer deadlineTimer.Stop() 151 } 152 deadlineTimer.Reset(deadline) 153 } 154 if s.dataForWriting == nil || s.cancelWriteErr != nil || s.closeForShutdownErr != nil { 155 break 156 } 157 } 158 159 s.mutex.Unlock() 160 if !notifiedSender { 161 s.sender.onHasStreamData(s.streamID) // must be called without holding the mutex 162 notifiedSender = true 163 } 164 if copied { 165 s.mutex.Lock() 166 break 167 } 168 if deadline.IsZero() { 169 <-s.writeChan 170 } else { 171 select { 172 case <-s.writeChan: 173 case <-deadlineTimer.Chan(): 174 deadlineTimer.SetRead() 175 } 176 } 177 s.mutex.Lock() 178 } 179 180 if bytesWritten == len(p) { 181 return bytesWritten, nil 182 } 183 if s.closeForShutdownErr != nil { 184 return bytesWritten, s.closeForShutdownErr 185 } else if s.cancelWriteErr != nil { 186 return bytesWritten, s.cancelWriteErr 187 } 188 return bytesWritten, nil 189 } 190 191 func (s *sendStream) canBufferStreamFrame() bool { 192 var l protocol.ByteCount 193 if s.nextFrame != nil { 194 l = s.nextFrame.DataLen() 195 } 196 return l+protocol.ByteCount(len(s.dataForWriting)) <= protocol.MaxPacketBufferSize 197 } 198 199 // popStreamFrame returns the next STREAM frame that is supposed to be sent on this stream 200 // maxBytes is the maximum length this frame (including frame header) will have. 201 func (s *sendStream) popStreamFrame(maxBytes protocol.ByteCount, v protocol.Version) (af ackhandler.StreamFrame, ok, hasMore bool) { 202 s.mutex.Lock() 203 f, hasMoreData := s.popNewOrRetransmittedStreamFrame(maxBytes, v) 204 if f != nil { 205 s.numOutstandingFrames++ 206 } 207 s.mutex.Unlock() 208 209 if f == nil { 210 return ackhandler.StreamFrame{}, false, hasMoreData 211 } 212 return ackhandler.StreamFrame{ 213 Frame: f, 214 Handler: (*sendStreamAckHandler)(s), 215 }, true, hasMoreData 216 } 217 218 func (s *sendStream) popNewOrRetransmittedStreamFrame(maxBytes protocol.ByteCount, v protocol.Version) (*wire.StreamFrame, bool /* has more data to send */) { 219 if s.cancelWriteErr != nil || s.closeForShutdownErr != nil { 220 return nil, false 221 } 222 223 if len(s.retransmissionQueue) > 0 { 224 f, hasMoreRetransmissions := s.maybeGetRetransmission(maxBytes, v) 225 if f != nil || hasMoreRetransmissions { 226 if f == nil { 227 return nil, true 228 } 229 // We always claim that we have more data to send. 230 // This might be incorrect, in which case there'll be a spurious call to popStreamFrame in the future. 231 return f, true 232 } 233 } 234 235 if len(s.dataForWriting) == 0 && s.nextFrame == nil { 236 if s.finishedWriting && !s.finSent { 237 s.finSent = true 238 return &wire.StreamFrame{ 239 StreamID: s.streamID, 240 Offset: s.writeOffset, 241 DataLenPresent: true, 242 Fin: true, 243 }, false 244 } 245 return nil, false 246 } 247 248 sendWindow := s.flowController.SendWindowSize() 249 if sendWindow == 0 { 250 if isBlocked, offset := s.flowController.IsNewlyBlocked(); isBlocked { 251 s.sender.queueControlFrame(&wire.StreamDataBlockedFrame{ 252 StreamID: s.streamID, 253 MaximumStreamData: offset, 254 }) 255 return nil, false 256 } 257 return nil, true 258 } 259 260 f, hasMoreData := s.popNewStreamFrame(maxBytes, sendWindow, v) 261 if dataLen := f.DataLen(); dataLen > 0 { 262 s.writeOffset += f.DataLen() 263 s.flowController.AddBytesSent(f.DataLen()) 264 } 265 f.Fin = s.finishedWriting && s.dataForWriting == nil && s.nextFrame == nil && !s.finSent 266 if f.Fin { 267 s.finSent = true 268 } 269 return f, hasMoreData 270 } 271 272 func (s *sendStream) popNewStreamFrame(maxBytes, sendWindow protocol.ByteCount, v protocol.Version) (*wire.StreamFrame, bool) { 273 if s.nextFrame != nil { 274 nextFrame := s.nextFrame 275 s.nextFrame = nil 276 277 maxDataLen := min(sendWindow, nextFrame.MaxDataLen(maxBytes, v)) 278 if nextFrame.DataLen() > maxDataLen { 279 s.nextFrame = wire.GetStreamFrame() 280 s.nextFrame.StreamID = s.streamID 281 s.nextFrame.Offset = s.writeOffset + maxDataLen 282 s.nextFrame.Data = s.nextFrame.Data[:nextFrame.DataLen()-maxDataLen] 283 s.nextFrame.DataLenPresent = true 284 copy(s.nextFrame.Data, nextFrame.Data[maxDataLen:]) 285 nextFrame.Data = nextFrame.Data[:maxDataLen] 286 } else { 287 s.signalWrite() 288 } 289 return nextFrame, s.nextFrame != nil || s.dataForWriting != nil 290 } 291 292 f := wire.GetStreamFrame() 293 f.Fin = false 294 f.StreamID = s.streamID 295 f.Offset = s.writeOffset 296 f.DataLenPresent = true 297 f.Data = f.Data[:0] 298 299 hasMoreData := s.popNewStreamFrameWithoutBuffer(f, maxBytes, sendWindow, v) 300 if len(f.Data) == 0 && !f.Fin { 301 f.PutBack() 302 return nil, hasMoreData 303 } 304 return f, hasMoreData 305 } 306 307 func (s *sendStream) popNewStreamFrameWithoutBuffer(f *wire.StreamFrame, maxBytes, sendWindow protocol.ByteCount, v protocol.Version) bool { 308 maxDataLen := f.MaxDataLen(maxBytes, v) 309 if maxDataLen == 0 { // a STREAM frame must have at least one byte of data 310 return s.dataForWriting != nil || s.nextFrame != nil || s.finishedWriting 311 } 312 s.getDataForWriting(f, min(maxDataLen, sendWindow)) 313 314 return s.dataForWriting != nil || s.nextFrame != nil || s.finishedWriting 315 } 316 317 func (s *sendStream) maybeGetRetransmission(maxBytes protocol.ByteCount, v protocol.Version) (*wire.StreamFrame, bool /* has more retransmissions */) { 318 f := s.retransmissionQueue[0] 319 newFrame, needsSplit := f.MaybeSplitOffFrame(maxBytes, v) 320 if needsSplit { 321 return newFrame, true 322 } 323 s.retransmissionQueue = s.retransmissionQueue[1:] 324 return f, len(s.retransmissionQueue) > 0 325 } 326 327 func (s *sendStream) hasData() bool { 328 s.mutex.Lock() 329 hasData := len(s.dataForWriting) > 0 330 s.mutex.Unlock() 331 return hasData 332 } 333 334 func (s *sendStream) getDataForWriting(f *wire.StreamFrame, maxBytes protocol.ByteCount) { 335 if protocol.ByteCount(len(s.dataForWriting)) <= maxBytes { 336 f.Data = f.Data[:len(s.dataForWriting)] 337 copy(f.Data, s.dataForWriting) 338 s.dataForWriting = nil 339 s.signalWrite() 340 return 341 } 342 f.Data = f.Data[:maxBytes] 343 copy(f.Data, s.dataForWriting) 344 s.dataForWriting = s.dataForWriting[maxBytes:] 345 if s.canBufferStreamFrame() { 346 s.signalWrite() 347 } 348 } 349 350 func (s *sendStream) isNewlyCompleted() bool { 351 completed := (s.finSent || s.cancelWriteErr != nil) && s.numOutstandingFrames == 0 && len(s.retransmissionQueue) == 0 352 if completed && !s.completed { 353 s.completed = true 354 return true 355 } 356 return false 357 } 358 359 func (s *sendStream) Close() error { 360 s.mutex.Lock() 361 if s.closeForShutdownErr != nil { 362 s.mutex.Unlock() 363 return nil 364 } 365 if s.cancelWriteErr != nil { 366 s.mutex.Unlock() 367 return fmt.Errorf("close called for canceled stream %d", s.streamID) 368 } 369 s.ctxCancel(nil) 370 s.finishedWriting = true 371 s.mutex.Unlock() 372 373 s.sender.onHasStreamData(s.streamID) // need to send the FIN, must be called without holding the mutex 374 return nil 375 } 376 377 func (s *sendStream) CancelWrite(errorCode StreamErrorCode) { 378 s.cancelWriteImpl(errorCode, false) 379 } 380 381 // must be called after locking the mutex 382 func (s *sendStream) cancelWriteImpl(errorCode qerr.StreamErrorCode, remote bool) { 383 s.mutex.Lock() 384 if s.cancelWriteErr != nil { 385 s.mutex.Unlock() 386 return 387 } 388 s.cancelWriteErr = &StreamError{StreamID: s.streamID, ErrorCode: errorCode, Remote: remote} 389 s.ctxCancel(s.cancelWriteErr) 390 s.numOutstandingFrames = 0 391 s.retransmissionQueue = nil 392 newlyCompleted := s.isNewlyCompleted() 393 s.mutex.Unlock() 394 395 s.signalWrite() 396 s.sender.queueControlFrame(&wire.ResetStreamFrame{ 397 StreamID: s.streamID, 398 FinalSize: s.writeOffset, 399 ErrorCode: errorCode, 400 }) 401 if newlyCompleted { 402 s.sender.onStreamCompleted(s.streamID) 403 } 404 } 405 406 func (s *sendStream) updateSendWindow(limit protocol.ByteCount) { 407 updated := s.flowController.UpdateSendWindow(limit) 408 if !updated { // duplicate or reordered MAX_STREAM_DATA frame 409 return 410 } 411 s.mutex.Lock() 412 hasStreamData := s.dataForWriting != nil || s.nextFrame != nil 413 s.mutex.Unlock() 414 if hasStreamData { 415 s.sender.onHasStreamData(s.streamID) 416 } 417 } 418 419 func (s *sendStream) handleStopSendingFrame(frame *wire.StopSendingFrame) { 420 s.cancelWriteImpl(frame.ErrorCode, true) 421 } 422 423 func (s *sendStream) Context() context.Context { 424 return s.ctx 425 } 426 427 func (s *sendStream) SetWriteDeadline(t time.Time) error { 428 s.mutex.Lock() 429 s.deadline = t 430 s.mutex.Unlock() 431 s.signalWrite() 432 return nil 433 } 434 435 // CloseForShutdown closes a stream abruptly. 436 // It makes Write unblock (and return the error) immediately. 437 // The peer will NOT be informed about this: the stream is closed without sending a FIN or RST. 438 func (s *sendStream) closeForShutdown(err error) { 439 s.mutex.Lock() 440 s.ctxCancel(err) 441 s.closeForShutdownErr = err 442 s.mutex.Unlock() 443 s.signalWrite() 444 } 445 446 // signalWrite performs a non-blocking send on the writeChan 447 func (s *sendStream) signalWrite() { 448 select { 449 case s.writeChan <- struct{}{}: 450 default: 451 } 452 } 453 454 type sendStreamAckHandler sendStream 455 456 var _ ackhandler.FrameHandler = &sendStreamAckHandler{} 457 458 func (s *sendStreamAckHandler) OnAcked(f wire.Frame) { 459 sf := f.(*wire.StreamFrame) 460 sf.PutBack() 461 s.mutex.Lock() 462 if s.cancelWriteErr != nil { 463 s.mutex.Unlock() 464 return 465 } 466 s.numOutstandingFrames-- 467 if s.numOutstandingFrames < 0 { 468 panic("numOutStandingFrames negative") 469 } 470 newlyCompleted := (*sendStream)(s).isNewlyCompleted() 471 s.mutex.Unlock() 472 473 if newlyCompleted { 474 s.sender.onStreamCompleted(s.streamID) 475 } 476 } 477 478 func (s *sendStreamAckHandler) OnLost(f wire.Frame) { 479 sf := f.(*wire.StreamFrame) 480 s.mutex.Lock() 481 if s.cancelWriteErr != nil { 482 s.mutex.Unlock() 483 return 484 } 485 sf.DataLenPresent = true 486 s.retransmissionQueue = append(s.retransmissionQueue, sf) 487 s.numOutstandingFrames-- 488 if s.numOutstandingFrames < 0 { 489 panic("numOutStandingFrames negative") 490 } 491 s.mutex.Unlock() 492 493 s.sender.onHasStreamData(s.streamID) 494 }