github.com/mutagen-io/mutagen@v0.18.0-rc1/pkg/multiplexing/multiplexer.go (about) 1 package multiplexing 2 3 import ( 4 "context" 5 "encoding/binary" 6 "errors" 7 "fmt" 8 "io" 9 "math" 10 "net" 11 "sync" 12 "time" 13 14 "github.com/mutagen-io/mutagen/pkg/multiplexing/ring" 15 ) 16 17 var ( 18 // ErrMultiplexerClosed is returned from operations that fail due to a 19 // multiplexer being closed. 20 ErrMultiplexerClosed = errors.New("multiplexer closed") 21 // ErrStreamRejected is returned from open operations that fail due to the 22 // remote endpoint rejecting the open request. 23 ErrStreamRejected = errors.New("stream rejected") 24 ) 25 26 // windowIncrement is used to pass a window increment from a stream to the 27 // multiplexer. 28 type windowIncrement struct { 29 // stream is the stream identifier. 30 stream uint64 31 // amount is the increment amount. 32 amount uint64 33 } 34 35 // Multiplexer provides bidirectional stream multiplexing. 36 type Multiplexer struct { 37 // even indicates whether or not the multiplexer uses even-numbered outbound 38 // stream identifiers. 39 even bool 40 // configuration is the multiplexer configuration. 41 configuration *Configuration 42 43 // closeOnce guards closure of closer and closed. 44 closeOnce sync.Once 45 // closer closes the underlying carrier. 46 closer io.Closer 47 // closed is closed when the underlying carrier is closed. 48 closed chan struct{} 49 // internalErrorLock guards access to internalError. 50 internalErrorLock sync.RWMutex 51 // internalError records the error associated with closure, if any. 52 internalError error 53 54 // streamLock guards nextOutboundStreamIdentifier and streams. 55 streamLock sync.Mutex 56 // nextOutboundStreamIdentifier is the next outbound stream identifier that 57 // will be used. It is set to 0 when outbound identifiers are exhausted. 58 nextOutboundStreamIdentifier uint64 59 // streams maps stream identifiers to their corresponding local stream 60 // objects. Stream objects perform their own deregistration when closed. 61 streams map[uint64]*Stream 62 // pendingInboundStreamIdentifiers is the backlog of pending inbound stream 63 // identifiers waiting to be accepted. It is written to only by the reader 64 // Goroutine. It has a capacity equal to the accept backlog size. 65 pendingInboundStreamIdentifiers chan uint64 66 67 // writeBufferAvailable is the channel where empty outbound message buffers 68 // are stored. If a buffer is in this channel, it is guaranteed to have 69 // sufficient free space to buffer any single message. Pollers on this 70 // channel should always poll on closed simultaneously and terminate if 71 // closed is closed. 72 writeBufferAvailable chan *messageBuffer 73 // writeBufferPending is the channel where non-empty outbound message 74 // buffers should be placed to enqueue them for transmission. Writes to this 75 // channel are only allowed by holders of outbound message buffers and are 76 // guaranteed never to block. 77 writeBufferPending chan *messageBuffer 78 79 // enqueueWindowIncrement enqueues transmission of a stream receive window 80 // increment message. The amount will be added to any pending window 81 // increment. This channel is unbuffered, but guaranteed to be approximately 82 // non-blocking as long as the multiplexer is not closed (as indicated by 83 // the closed channel). 84 enqueueWindowIncrement chan windowIncrement 85 // enqueueCloseWrite enqueues transmission of a stream close write message. 86 // It should be provided with the stream identifier. This channel is 87 // unbuffered, but guaranteed to be approximately non-blocking as long as 88 // the multiplexer is not closed (as indicated by the closed channel). 89 enqueueCloseWrite chan uint64 90 // enqueueClose enqueues transmission of a stream close message. It should 91 // be provided with the stream identifier. Any pending window increment or 92 // close write messages will be cancelled. This channel is unbuffered, but 93 // guaranteed to be approximately non-blocking as long as the multiplexer is 94 // not closed (as indicated by the closed channel). 95 enqueueClose chan uint64 96 } 97 98 // Multiplex creates a new multiplexer on top of an existing carrier stream. The 99 // multiplexer takes ownership of the carrier, so it should not be used directly 100 // after being passed to this function. 101 // 102 // Multiplexers are symmetric, meaning that a multiplexer at either end of the 103 // carrier can both open and accept connections. However, a single asymmetric 104 // parameter is required to avoid the need for negotiating stream identifiers, 105 // so the even parameter must be set to true on one endpoint and false on the 106 // other (using some implicit or out-of-band coordination mechanism, such as 107 // false for client and true for server). The value of even has no observable 108 // effect on the multiplexer, other than determining the evenness of outbound 109 // stream identifiers. 110 // 111 // If configuration is nil, the default configuration will be used. 112 func Multiplex(carrier Carrier, even bool, configuration *Configuration) *Multiplexer { 113 // If no configuration was provided, then use default values, otherwise 114 // normalize any out-of-range values provided by the caller. 115 if configuration == nil { 116 configuration = DefaultConfiguration() 117 } else { 118 configuration.normalize() 119 } 120 121 // Create the multiplexer. 122 multiplexer := &Multiplexer{ 123 even: even, 124 configuration: configuration, 125 closer: carrier, 126 closed: make(chan struct{}), 127 streams: make(map[uint64]*Stream), 128 pendingInboundStreamIdentifiers: make(chan uint64, configuration.AcceptBacklog), 129 writeBufferAvailable: make(chan *messageBuffer, configuration.WriteBufferCount), 130 writeBufferPending: make(chan *messageBuffer, configuration.WriteBufferCount), 131 enqueueWindowIncrement: make(chan windowIncrement), 132 enqueueCloseWrite: make(chan uint64), 133 enqueueClose: make(chan uint64), 134 } 135 if even { 136 multiplexer.nextOutboundStreamIdentifier = 2 137 } else { 138 multiplexer.nextOutboundStreamIdentifier = 1 139 } 140 for i := 0; i < configuration.WriteBufferCount; i++ { 141 multiplexer.writeBufferAvailable <- newMessageBuffer() 142 } 143 144 // Start the multiplexer's background Goroutines. 145 go multiplexer.run(carrier) 146 147 // Done. 148 return multiplexer 149 } 150 151 // run is the primary entry point for the multiplexer's background Goroutines. 152 func (m *Multiplexer) run(carrier Carrier) { 153 // Start the reader Goroutine and monitor for its termination. 154 heartbeats := make(chan struct{}, 1) 155 readErrors := make(chan error, 1) 156 go func() { 157 readErrors <- m.read(carrier, heartbeats) 158 }() 159 160 // Start the writer Goroutine and monitor for its termination. 161 writeErrors := make(chan error, 1) 162 go func() { 163 writeErrors <- m.write(carrier) 164 }() 165 166 // Start the state accumulation/transmission Goroutine. It will only 167 // terminate when the multiplexer is closed. 168 go m.enqueue() 169 170 // Create a timer to enforce heartbeat reception and defer its shutdown. If 171 // inbound heartbeats are not required, then just leave the timer stopped. 172 heartbeatTimeout := time.NewTimer(m.configuration.MaximumHeartbeatReceiveInterval) 173 if m.configuration.MaximumHeartbeatReceiveInterval > 0 { 174 defer heartbeatTimeout.Stop() 175 } else { 176 if !heartbeatTimeout.Stop() { 177 <-heartbeatTimeout.C 178 } 179 } 180 181 // Loop until failure or multiplexer closure. 182 for { 183 select { 184 case <-heartbeats: 185 if m.configuration.MaximumHeartbeatReceiveInterval > 0 { 186 if !heartbeatTimeout.Stop() { 187 <-heartbeatTimeout.C 188 } 189 heartbeatTimeout.Reset(m.configuration.MaximumHeartbeatReceiveInterval) 190 } 191 case err := <-readErrors: 192 m.closeWithError(fmt.Errorf("read error: %w", err)) 193 return 194 case err := <-writeErrors: 195 m.closeWithError(fmt.Errorf("write error: %w", err)) 196 return 197 case <-heartbeatTimeout.C: 198 m.closeWithError(errors.New("heartbeat timeout")) 199 return 200 case <-m.closed: 201 return 202 } 203 } 204 } 205 206 // read is the entry point for the reader Goroutine. 207 func (m *Multiplexer) read(reader Carrier, heartbeats chan<- struct{}) error { 208 // Create a buffer for reading stream data lengths, which are encoded as 209 // 16-bit unsigned integers. 210 var lengthBuffer [2]byte 211 212 // Track the range of stream identifiers used by the remote. 213 var largestOpenedInboundStreamIdentifier uint64 214 215 // Loop until failure or multiplexure closure. 216 for { 217 // Read the next message type. 218 var kind messageKind 219 if k, err := reader.ReadByte(); err != nil { 220 return fmt.Errorf("unable to read message kind: %w", err) 221 } else { 222 kind = messageKind(k) 223 } 224 225 // Ensure that the message kind is valid. 226 if kind > messageKindStreamClose { 227 return fmt.Errorf("received unknown message kind: %#02x", kind) 228 } 229 230 // If this is a multiplexer heartbeat message, then strobe the heartbeat 231 // channel and continue to the next message. 232 if kind == messageKindMultiplexerHeartbeat { 233 select { 234 case heartbeats <- struct{}{}: 235 default: 236 } 237 continue 238 } 239 240 // At this point, we know that this is a stream message, so decode the 241 // stream identifier and perform basic validation. 242 streamIdentifier, err := binary.ReadUvarint(reader) 243 if err != nil { 244 return fmt.Errorf("unable to read stream identifier (message kind %#02x): %w", kind, err) 245 } else if streamIdentifier == 0 { 246 return fmt.Errorf("zero-value stream identifier received (message kind %#02x)", kind) 247 } 248 249 // Verify that the stream identifier falls with an acceptable range, 250 // depending on its origin and the message kind, and look up the 251 // corresponding stream object, if applicable. 252 streamIdentifierIsOutbound := m.even == (streamIdentifier%2 == 0) 253 var stream *Stream 254 if kind == messageKindStreamOpen { 255 if streamIdentifierIsOutbound { 256 return errors.New("outbound stream identifier used by remote to open stream") 257 } else if streamIdentifier <= largestOpenedInboundStreamIdentifier { 258 return errors.New("remote stream identifiers not monotonically increasing") 259 } 260 largestOpenedInboundStreamIdentifier = streamIdentifier 261 } else if kind == messageKindStreamAccept && !streamIdentifierIsOutbound { 262 return errors.New("inbound stream identifier used by remote to accept stream") 263 } else { 264 inboundStreamIdentifierOutOfRange := !streamIdentifierIsOutbound && 265 streamIdentifier > largestOpenedInboundStreamIdentifier 266 if inboundStreamIdentifierOutOfRange { 267 return fmt.Errorf("message (%#02x) received for unopened inbound stream identifier", kind) 268 } 269 m.streamLock.Lock() 270 outboundStreamIdentifierOutOfRange := streamIdentifierIsOutbound && 271 m.nextOutboundStreamIdentifier != 0 && 272 streamIdentifier >= m.nextOutboundStreamIdentifier 273 if outboundStreamIdentifierOutOfRange { 274 m.streamLock.Unlock() 275 return fmt.Errorf("message (%#02x) received for unused outbound stream identifier", kind) 276 } 277 stream = m.streams[streamIdentifier] 278 m.streamLock.Unlock() 279 } 280 281 // Handle the remainder of the message based on kind. 282 if kind == messageKindStreamOpen { 283 // Decode the remote's initial receive window size. 284 windowSize, err := binary.ReadUvarint(reader) 285 if err != nil { 286 return fmt.Errorf("unable to read initial stream window size on open: %w", err) 287 } 288 289 // If there's no capacity for additional streams in the backlog, 290 // then enqueue a close message to reject the stream. 291 if len(m.pendingInboundStreamIdentifiers) == m.configuration.AcceptBacklog { 292 select { 293 case m.enqueueClose <- streamIdentifier: 294 continue 295 case <-m.closed: 296 return ErrMultiplexerClosed 297 } 298 } 299 300 // Create the local end of the stream. 301 stream := newStream(m, streamIdentifier, m.configuration.StreamReceiveWindow) 302 303 // Set the stream's initial write window. 304 stream.sendWindow = windowSize 305 if windowSize > 0 { 306 stream.sendWindowReady <- struct{}{} 307 } 308 309 // Register the stream. 310 m.streamLock.Lock() 311 m.streams[streamIdentifier] = stream 312 m.streamLock.Unlock() 313 314 // Enqueue the stream for acceptance. 315 m.pendingInboundStreamIdentifiers <- streamIdentifier 316 } else if kind == messageKindStreamAccept { 317 // Decode the remote's initial receive window size. 318 windowSize, err := binary.ReadUvarint(reader) 319 if err != nil { 320 return fmt.Errorf("unable to read initial stream window size on accept: %w", err) 321 } 322 323 // If the stream wasn't found locally, then we just have to assume 324 // that the open request was already cancelled and that a close 325 // response was already sent to the remote. In theory, there could 326 // be misbehavior here from the remote, but we have no way to track 327 // or detect it. In this case, we discard the message. 328 if stream == nil { 329 continue 330 } 331 332 // Verify that the stream wasn't already accepted or rejected. 333 if isClosed(stream.established) { 334 return errors.New("remote accepted the same stream twice") 335 } else if isClosed(stream.remoteClosed) { 336 return errors.New("remote accepted stream after closing it") 337 } 338 339 // Set the stream's initial write window. We don't need to lock the 340 // write window at this point since the stream hasn't been returned 341 // to the caller of OpenStream yet. 342 stream.sendWindow = windowSize 343 if windowSize > 0 { 344 stream.sendWindowReady <- struct{}{} 345 } 346 347 // Mark the stream as accepted. 348 close(stream.established) 349 } else if kind == messageKindStreamData { 350 // Decode the data length. 351 if _, err := io.ReadFull(reader, lengthBuffer[:]); err != nil { 352 return fmt.Errorf("unable to read data length: %w", err) 353 } 354 length := int(binary.BigEndian.Uint16(lengthBuffer[:])) 355 if length == 0 { 356 return errors.New("zero-length data received") 357 } 358 359 // If the stream wasn't found locally, then we just have to assume 360 // that it was already closed locally and deregistered. In theory, 361 // there could be misbehavior here from the remote, but we have no 362 // way to track or detect it. In this case, we discard the data. 363 if stream == nil { 364 if _, err := reader.Discard(length); err != nil { 365 return fmt.Errorf("unable to discard data: %w", err) 366 } 367 continue 368 } 369 370 // Verify that the stream has been established and isn't closed for 371 // writing or closed. 372 if !isClosed(stream.established) { 373 return errors.New("data received for partially established stream") 374 } else if isClosed(stream.remoteClosedWrite) { 375 return errors.New("data received for write-closed stream") 376 } else if isClosed(stream.remoteClosed) { 377 return errors.New("data received for closed stream") 378 } 379 380 // Record the data. 381 stream.receiveBufferLock.Lock() 382 if _, err := stream.receiveBuffer.ReadNFrom(reader, length); err != nil { 383 stream.receiveBufferLock.Unlock() 384 if err == ring.ErrBufferFull { 385 return errors.New("remote violated stream receive window") 386 } 387 return fmt.Errorf("unable to read stream data into buffer: %w", err) 388 } 389 if stream.receiveBuffer.Used() == length { 390 stream.receiveBufferReady <- struct{}{} 391 } 392 stream.receiveBufferLock.Unlock() 393 } else if kind == messageKindStreamWindowIncrement { 394 // Decode the remote's receive window size increment. 395 windowSizeIncrement, err := binary.ReadUvarint(reader) 396 if err != nil { 397 return fmt.Errorf("unable to read stream window size increment: %w", err) 398 } else if windowSizeIncrement == 0 { 399 return errors.New("zero-valued window increment received") 400 } 401 402 // If the stream wasn't found locally, then we just have to assume 403 // that it was already closed locally and deregistered. In theory, 404 // there could be misbehavior here from the remote, but we have no 405 // way to track or detect it. In this case, we discard the message. 406 if stream == nil { 407 continue 408 } 409 410 // If this is an outbound stream, then ensure that the stream is 411 // established (i.e. it's been accepted by the remote) before 412 // allowing window increments. For inbound streams, we allow 413 // adjustments to the window size before we accept the stream 414 // locally, even though we don't utilize this feature at the moment. 415 if streamIdentifierIsOutbound && !isClosed(stream.established) { 416 return errors.New("window increment received for partially established outbound stream") 417 } 418 419 // Verify that the stream isn't already closed. 420 if isClosed(stream.remoteClosed) { 421 return errors.New("window increment received for closed stream") 422 } 423 424 // Increment the window. 425 stream.sendWindowLock.Lock() 426 if stream.sendWindow == 0 { 427 stream.sendWindow = windowSizeIncrement 428 stream.sendWindowReady <- struct{}{} 429 } else { 430 if math.MaxUint64-stream.sendWindow < windowSizeIncrement { 431 stream.sendWindowLock.Unlock() 432 return errors.New("window increment overflows maximum value") 433 } 434 stream.sendWindow += windowSizeIncrement 435 } 436 stream.sendWindowLock.Unlock() 437 } else if kind == messageKindStreamCloseWrite { 438 // If the stream wasn't found locally, then we just have to assume 439 // that it was already closed locally and deregistered. In theory, 440 // there could be misbehavior here from the remote, but we have no 441 // way to track or detect it. In this case, we discard the message. 442 if stream == nil { 443 continue 444 } 445 446 // If this is an outbound stream, then ensure that the stream is 447 // established (i.e. it's been accepted by the remote) before 448 // allowing write closure. For inbound streams, we allow write 449 // closure before we accept the stream locally, even though we don't 450 // utilize this feature at the moment. 451 if streamIdentifierIsOutbound && !isClosed(stream.established) { 452 return errors.New("close write received for partially established outbound stream") 453 } 454 455 // Verify that the stream isn't already closed or closed for writes. 456 if isClosed(stream.remoteClosed) { 457 return errors.New("close write received for closed stream") 458 } else if isClosed(stream.remoteClosedWrite) { 459 return errors.New("close write received for the same stream twice") 460 } 461 462 // Signal write closure. 463 close(stream.remoteClosedWrite) 464 } else if kind == messageKindStreamClose { 465 // If the stream wasn't found locally, then we just have to assume 466 // that it was already closed locally and deregistered. In theory, 467 // there could be misbehavior here from the remote, but we have no 468 // way to track or detect it. In this case, we discard the message. 469 if stream == nil { 470 continue 471 } 472 473 // Verify that the stream isn't already closed. 474 if isClosed(stream.remoteClosed) { 475 return errors.New("close received the same stream twice") 476 } 477 478 // Signal closure. 479 close(stream.remoteClosed) 480 } else { 481 panic("unhandled message kind") 482 } 483 } 484 } 485 486 // write is the entry point for the writer Goroutine. 487 func (m *Multiplexer) write(writer Carrier) error { 488 // If outbound heartbeats are enabled, then create a ticker to regulate 489 // heartbeat transmission, defer its shutdown, and craft a reusable 490 // heartbeat message. 491 var heartbeatTicker *time.Ticker 492 var writeHeartbeat <-chan time.Time 493 var heartbeat []byte 494 if m.configuration.HeartbeatTransmitInterval > 0 { 495 heartbeatTicker = time.NewTicker(m.configuration.HeartbeatTransmitInterval) 496 defer heartbeatTicker.Stop() 497 writeHeartbeat = heartbeatTicker.C 498 heartbeat = []byte{byte(messageKindMultiplexerHeartbeat)} 499 } 500 501 // Loop until failure or multiplexer closure. 502 for { 503 select { 504 case <-writeHeartbeat: 505 if _, err := writer.Write(heartbeat); err != nil { 506 return fmt.Errorf("unable to write heartbeat: %w", err) 507 } 508 case writeBuffer := <-m.writeBufferPending: 509 if _, err := writeBuffer.WriteTo(writer); err != nil { 510 return fmt.Errorf("unable to write message buffer: %w", err) 511 } 512 m.writeBufferAvailable <- writeBuffer 513 case <-m.closed: 514 return ErrMultiplexerClosed 515 } 516 } 517 } 518 519 // enqueue is the entry point for the state accumulation/transmission Goroutine. 520 func (m *Multiplexer) enqueue() { 521 // Track pending updates. 522 windowIncrements := make(map[uint64]uint64) 523 writeCloses := make(map[uint64]bool) 524 closes := make(map[uint64]bool) 525 526 // Loop and process updates until failure. 527 for { 528 // Determine whether or not to poll for write buffer availability (based 529 // on whether or not we have any pending updates). 530 writeBufferAvailable := m.writeBufferAvailable 531 if len(windowIncrements) == 0 && len(writeCloses) == 0 && len(closes) == 0 { 532 writeBufferAvailable = nil 533 } 534 535 // Poll for a write buffer (if applicable), an update, or termination. 536 // If we get a write buffer, then write as many updates as we can. 537 select { 538 case writeBuffer := <-writeBufferAvailable: 539 for stream, amount := range windowIncrements { 540 if writeBuffer.canEncodeStreamWindowIncrement() { 541 writeBuffer.encodeStreamWindowIncrement(stream, amount) 542 delete(windowIncrements, stream) 543 } else { 544 break 545 } 546 } 547 for stream := range writeCloses { 548 if writeBuffer.canEncodeStreamCloseWrite() { 549 writeBuffer.encodeStreamCloseWrite(stream) 550 delete(writeCloses, stream) 551 } else { 552 break 553 } 554 } 555 for stream := range closes { 556 if writeBuffer.canEncodeStreamClose() { 557 writeBuffer.encodeStreamClose(stream) 558 delete(closes, stream) 559 } else { 560 break 561 } 562 } 563 m.writeBufferPending <- writeBuffer 564 case increment := <-m.enqueueWindowIncrement: 565 windowIncrements[increment.stream] = windowIncrements[increment.stream] + increment.amount 566 case stream := <-m.enqueueCloseWrite: 567 writeCloses[stream] = true 568 case stream := <-m.enqueueClose: 569 delete(windowIncrements, stream) 570 delete(writeCloses, stream) 571 closes[stream] = true 572 case <-m.closed: 573 return 574 } 575 } 576 } 577 578 // Addr implements net.Listener.Addr. 579 func (m *Multiplexer) Addr() net.Addr { 580 return &multiplexerAddress{even: m.even} 581 } 582 583 // OpenStream opens a new stream, cancelling the open operation if the provided 584 // context is cancelled, an error occurs, or the multiplexer is closed. The 585 // context must not be nil. The context only regulates the lifetime of the open 586 // operation, not the stream itself. 587 func (m *Multiplexer) OpenStream(ctx context.Context) (*Stream, error) { 588 // Create and register the local side of the stream. If we've already 589 // exhausted local stream identifiers, then we can't open a new stream. 590 m.streamLock.Lock() 591 if m.nextOutboundStreamIdentifier == 0 { 592 m.streamLock.Unlock() 593 return nil, errors.New("local stream identifiers exhausted") 594 } 595 stream := newStream(m, m.nextOutboundStreamIdentifier, m.configuration.StreamReceiveWindow) 596 m.streams[m.nextOutboundStreamIdentifier] = stream 597 if math.MaxUint64-m.nextOutboundStreamIdentifier < 2 { 598 m.nextOutboundStreamIdentifier = 0 599 } else { 600 m.nextOutboundStreamIdentifier += 2 601 } 602 m.streamLock.Unlock() 603 604 // If we fail to establish the stream, then defer its closure. We can't use 605 // the stream's established channel to check this because it could be closed 606 // by the reader Goroutine after some other error aborts the opening. 607 var sentOpenMessage, established bool 608 defer func() { 609 if !established { 610 stream.close(sentOpenMessage) 611 } 612 }() 613 614 // Write the open message and queue it for transmission. 615 select { 616 case writeBuffer := <-m.writeBufferAvailable: 617 writeBuffer.encodeOpenMessage(stream.identifier, uint64(m.configuration.StreamReceiveWindow)) 618 m.writeBufferPending <- writeBuffer 619 sentOpenMessage = true 620 case <-ctx.Done(): 621 return nil, context.Canceled 622 case <-m.closed: 623 return nil, ErrMultiplexerClosed 624 } 625 626 // Wait for stream acceptance or rejection. 627 select { 628 case <-stream.established: 629 established = true 630 return stream, nil 631 case <-stream.remoteClosed: 632 return nil, ErrStreamRejected 633 case <-ctx.Done(): 634 return nil, context.Canceled 635 case <-m.closed: 636 return nil, ErrMultiplexerClosed 637 } 638 } 639 640 // errStaleInboundStream indicates that a stale inbound stream was encountered. 641 var errStaleInboundStream = errors.New("stale inbound stream") 642 643 // acceptOneStream is the internal stream accept method. It will only attempt 644 // one accept, and will return errStaleInboundStream if the accept request fails 645 // due to a stale inbound stream. 646 func (m *Multiplexer) acceptOneStream(ctx context.Context) (*Stream, error) { 647 // Grab the oldest pending stream identifier. 648 var streamIdentifier uint64 649 select { 650 case streamIdentifier = <-m.pendingInboundStreamIdentifiers: 651 case <-ctx.Done(): 652 return nil, context.Canceled 653 case <-m.closed: 654 return nil, ErrMultiplexerClosed 655 } 656 657 // Grab the associated stream object, which is guaranteed to be non-nil. 658 m.streamLock.Lock() 659 stream := m.streams[streamIdentifier] 660 m.streamLock.Unlock() 661 662 // If we fail to establish the stream, then defer its closure. In this case 663 // (unlike the opening case) we can use the stream's established channel to 664 // check this because we're responsible for closing it. 665 defer func() { 666 if !isClosed(stream.established) { 667 stream.Close() 668 } 669 }() 670 671 // Wait for a write buffer to become available. 672 var writeBuffer *messageBuffer 673 select { 674 case writeBuffer = <-m.writeBufferAvailable: 675 case <-stream.remoteClosed: 676 return nil, errStaleInboundStream 677 case <-ctx.Done(): 678 return nil, context.Canceled 679 case <-m.closed: 680 return nil, ErrMultiplexerClosed 681 } 682 683 // Mark the stream as established. We need to do this before transmitting 684 // the accept message because the other side might start sending messages 685 // immediately and the reader Goroutine will want to confirm establishment 686 // when processing those messages. 687 close(stream.established) 688 689 // Write the accept message and queue it for transmission. 690 writeBuffer.encodeAcceptMessage(streamIdentifier, uint64(m.configuration.StreamReceiveWindow)) 691 m.writeBufferPending <- writeBuffer 692 693 // Success. 694 return stream, nil 695 } 696 697 // AcceptContext accepts an incoming stream. 698 func (m *Multiplexer) AcceptStream(ctx context.Context) (*Stream, error) { 699 // Loop until we find a pending stream that's not stale or encounter some 700 // other error. 701 for { 702 stream, err := m.acceptOneStream(ctx) 703 if err == errStaleInboundStream { 704 continue 705 } 706 return stream, err 707 } 708 } 709 710 // Accept implements net.Listener.Accept. It is implemented as a wrapper around 711 // AcceptStream and simply casts the resulting stream to a net.Conn. 712 func (m *Multiplexer) Accept() (net.Conn, error) { 713 stream, err := m.AcceptStream(context.Background()) 714 return stream, err 715 } 716 717 // Closed returns a channel that is closed when the multiplexer is closed (due 718 // to either internal failure or a manual call to Close). 719 func (m *Multiplexer) Closed() <-chan struct{} { 720 return m.closed 721 } 722 723 // InternalError returns any internal error that caused the multiplexer to 724 // close (as indicated by closure of the result of Closed). It returns nil if 725 // Close was manually invoked. 726 func (m *Multiplexer) InternalError() error { 727 m.internalErrorLock.RLock() 728 defer m.internalErrorLock.RUnlock() 729 return m.internalError 730 } 731 732 // closeWithError is the internal close method that allows for optional error 733 // reporting when closing. 734 func (m *Multiplexer) closeWithError(internalError error) (err error) { 735 m.closeOnce.Do(func() { 736 err = m.closer.Close() 737 if internalError != nil { 738 m.internalErrorLock.Lock() 739 m.internalError = internalError 740 m.internalErrorLock.Unlock() 741 } 742 close(m.closed) 743 }) 744 return 745 } 746 747 // Close implements net.Listener.Close. Only the first call to Close will have 748 // any effect. Subsequent calls will behave as no-ops and return nil errors. 749 func (m *Multiplexer) Close() error { 750 return m.closeWithError(nil) 751 }