github.com/mutagen-io/mutagen@v0.18.0-rc1/pkg/multiplexing/multiplexer.go (about)

     1  package multiplexing
     2  
     3  import (
     4  	"context"
     5  	"encoding/binary"
     6  	"errors"
     7  	"fmt"
     8  	"io"
     9  	"math"
    10  	"net"
    11  	"sync"
    12  	"time"
    13  
    14  	"github.com/mutagen-io/mutagen/pkg/multiplexing/ring"
    15  )
    16  
var (
	// ErrMultiplexerClosed is returned from operations that fail due to a
	// multiplexer being closed. Within this file it is returned by OpenStream
	// and the stream acceptance methods when the multiplexer's closed channel
	// is closed, and by the reader and writer Goroutines when they observe
	// closure.
	ErrMultiplexerClosed = errors.New("multiplexer closed")
	// ErrStreamRejected is returned from open operations that fail due to the
	// remote endpoint rejecting the open request, i.e. when the remote closes
	// the stream before it has been established.
	ErrStreamRejected = errors.New("stream rejected")
)
    25  
    26  // windowIncrement is used to pass a window increment from a stream to the
    27  // multiplexer.
    28  type windowIncrement struct {
    29  	// stream is the stream identifier.
    30  	stream uint64
    31  	// amount is the increment amount.
    32  	amount uint64
    33  }
    34  
// Multiplexer provides bidirectional stream multiplexing on top of a single
// carrier. Streams can be opened with OpenStream and accepted with
// AcceptStream. The Addr, Accept, and Close methods are documented as
// implementing the corresponding net.Listener methods.
type Multiplexer struct {
	// even indicates whether or not the multiplexer uses even-numbered outbound
	// stream identifiers. Inbound identifiers have the opposite evenness.
	even bool
	// configuration is the multiplexer configuration. It is never nil once the
	// multiplexer has been created by Multiplex.
	configuration *Configuration

	// closeOnce guards closure of closer and closed, ensuring that shutdown is
	// only ever performed once.
	closeOnce sync.Once
	// closer closes the underlying carrier.
	closer io.Closer
	// closed is closed when the underlying carrier is closed.
	closed chan struct{}
	// internalErrorLock guards access to internalError.
	internalErrorLock sync.RWMutex
	// internalError records the error associated with closure, if any. It is
	// set before closed is closed.
	internalError error

	// streamLock guards nextOutboundStreamIdentifier and streams.
	streamLock sync.Mutex
	// nextOutboundStreamIdentifier is the next outbound stream identifier that
	// will be used. It is set to 0 when outbound identifiers are exhausted.
	nextOutboundStreamIdentifier uint64
	// streams maps stream identifiers to their corresponding local stream
	// objects. Stream objects perform their own deregistration when closed.
	streams map[uint64]*Stream
	// pendingInboundStreamIdentifiers is the backlog of pending inbound stream
	// identifiers waiting to be accepted. It is written to only by the reader
	// Goroutine. It has a capacity equal to the accept backlog size.
	pendingInboundStreamIdentifiers chan uint64

	// writeBufferAvailable is the channel where empty outbound message buffers
	// are stored. If a buffer is in this channel, it is guaranteed to have
	// sufficient free space to buffer any single message. Pollers on this
	// channel should always poll on closed simultaneously and terminate if
	// closed is closed.
	writeBufferAvailable chan *messageBuffer
	// writeBufferPending is the channel where non-empty outbound message
	// buffers should be placed to enqueue them for transmission. Writes to this
	// channel are only allowed by holders of outbound message buffers and are
	// guaranteed never to block (its capacity matches the number of buffers).
	writeBufferPending chan *messageBuffer

	// enqueueWindowIncrement enqueues transmission of a stream receive window
	// increment message. The amount will be added to any pending window
	// increment. This channel is unbuffered, but guaranteed to be approximately
	// non-blocking as long as the multiplexer is not closed (as indicated by
	// the closed channel).
	enqueueWindowIncrement chan windowIncrement
	// enqueueCloseWrite enqueues transmission of a stream close write message.
	// It should be provided with the stream identifier. This channel is
	// unbuffered, but guaranteed to be approximately non-blocking as long as
	// the multiplexer is not closed (as indicated by the closed channel).
	enqueueCloseWrite chan uint64
	// enqueueClose enqueues transmission of a stream close message. It should
	// be provided with the stream identifier. Any pending window increment or
	// close write messages will be cancelled. This channel is unbuffered, but
	// guaranteed to be approximately non-blocking as long as the multiplexer is
	// not closed (as indicated by the closed channel).
	enqueueClose chan uint64
}
    97  
    98  // Multiplex creates a new multiplexer on top of an existing carrier stream. The
    99  // multiplexer takes ownership of the carrier, so it should not be used directly
   100  // after being passed to this function.
   101  //
   102  // Multiplexers are symmetric, meaning that a multiplexer at either end of the
   103  // carrier can both open and accept connections. However, a single asymmetric
   104  // parameter is required to avoid the need for negotiating stream identifiers,
   105  // so the even parameter must be set to true on one endpoint and false on the
   106  // other (using some implicit or out-of-band coordination mechanism, such as
   107  // false for client and true for server). The value of even has no observable
   108  // effect on the multiplexer, other than determining the evenness of outbound
   109  // stream identifiers.
   110  //
   111  // If configuration is nil, the default configuration will be used.
   112  func Multiplex(carrier Carrier, even bool, configuration *Configuration) *Multiplexer {
   113  	// If no configuration was provided, then use default values, otherwise
   114  	// normalize any out-of-range values provided by the caller.
   115  	if configuration == nil {
   116  		configuration = DefaultConfiguration()
   117  	} else {
   118  		configuration.normalize()
   119  	}
   120  
   121  	// Create the multiplexer.
   122  	multiplexer := &Multiplexer{
   123  		even:                            even,
   124  		configuration:                   configuration,
   125  		closer:                          carrier,
   126  		closed:                          make(chan struct{}),
   127  		streams:                         make(map[uint64]*Stream),
   128  		pendingInboundStreamIdentifiers: make(chan uint64, configuration.AcceptBacklog),
   129  		writeBufferAvailable:            make(chan *messageBuffer, configuration.WriteBufferCount),
   130  		writeBufferPending:              make(chan *messageBuffer, configuration.WriteBufferCount),
   131  		enqueueWindowIncrement:          make(chan windowIncrement),
   132  		enqueueCloseWrite:               make(chan uint64),
   133  		enqueueClose:                    make(chan uint64),
   134  	}
   135  	if even {
   136  		multiplexer.nextOutboundStreamIdentifier = 2
   137  	} else {
   138  		multiplexer.nextOutboundStreamIdentifier = 1
   139  	}
   140  	for i := 0; i < configuration.WriteBufferCount; i++ {
   141  		multiplexer.writeBufferAvailable <- newMessageBuffer()
   142  	}
   143  
   144  	// Start the multiplexer's background Goroutines.
   145  	go multiplexer.run(carrier)
   146  
   147  	// Done.
   148  	return multiplexer
   149  }
   150  
   151  // run is the primary entry point for the multiplexer's background Goroutines.
   152  func (m *Multiplexer) run(carrier Carrier) {
   153  	// Start the reader Goroutine and monitor for its termination.
   154  	heartbeats := make(chan struct{}, 1)
   155  	readErrors := make(chan error, 1)
   156  	go func() {
   157  		readErrors <- m.read(carrier, heartbeats)
   158  	}()
   159  
   160  	// Start the writer Goroutine and monitor for its termination.
   161  	writeErrors := make(chan error, 1)
   162  	go func() {
   163  		writeErrors <- m.write(carrier)
   164  	}()
   165  
   166  	// Start the state accumulation/transmission Goroutine. It will only
   167  	// terminate when the multiplexer is closed.
   168  	go m.enqueue()
   169  
   170  	// Create a timer to enforce heartbeat reception and defer its shutdown. If
   171  	// inbound heartbeats are not required, then just leave the timer stopped.
   172  	heartbeatTimeout := time.NewTimer(m.configuration.MaximumHeartbeatReceiveInterval)
   173  	if m.configuration.MaximumHeartbeatReceiveInterval > 0 {
   174  		defer heartbeatTimeout.Stop()
   175  	} else {
   176  		if !heartbeatTimeout.Stop() {
   177  			<-heartbeatTimeout.C
   178  		}
   179  	}
   180  
   181  	// Loop until failure or multiplexer closure.
   182  	for {
   183  		select {
   184  		case <-heartbeats:
   185  			if m.configuration.MaximumHeartbeatReceiveInterval > 0 {
   186  				if !heartbeatTimeout.Stop() {
   187  					<-heartbeatTimeout.C
   188  				}
   189  				heartbeatTimeout.Reset(m.configuration.MaximumHeartbeatReceiveInterval)
   190  			}
   191  		case err := <-readErrors:
   192  			m.closeWithError(fmt.Errorf("read error: %w", err))
   193  			return
   194  		case err := <-writeErrors:
   195  			m.closeWithError(fmt.Errorf("write error: %w", err))
   196  			return
   197  		case <-heartbeatTimeout.C:
   198  			m.closeWithError(errors.New("heartbeat timeout"))
   199  			return
   200  		case <-m.closed:
   201  			return
   202  		}
   203  	}
   204  }
   205  
   206  // read is the entry point for the reader Goroutine.
   207  func (m *Multiplexer) read(reader Carrier, heartbeats chan<- struct{}) error {
   208  	// Create a buffer for reading stream data lengths, which are encoded as
   209  	// 16-bit unsigned integers.
   210  	var lengthBuffer [2]byte
   211  
   212  	// Track the range of stream identifiers used by the remote.
   213  	var largestOpenedInboundStreamIdentifier uint64
   214  
   215  	// Loop until failure or multiplexure closure.
   216  	for {
   217  		// Read the next message type.
   218  		var kind messageKind
   219  		if k, err := reader.ReadByte(); err != nil {
   220  			return fmt.Errorf("unable to read message kind: %w", err)
   221  		} else {
   222  			kind = messageKind(k)
   223  		}
   224  
   225  		// Ensure that the message kind is valid.
   226  		if kind > messageKindStreamClose {
   227  			return fmt.Errorf("received unknown message kind: %#02x", kind)
   228  		}
   229  
   230  		// If this is a multiplexer heartbeat message, then strobe the heartbeat
   231  		// channel and continue to the next message.
   232  		if kind == messageKindMultiplexerHeartbeat {
   233  			select {
   234  			case heartbeats <- struct{}{}:
   235  			default:
   236  			}
   237  			continue
   238  		}
   239  
   240  		// At this point, we know that this is a stream message, so decode the
   241  		// stream identifier and perform basic validation.
   242  		streamIdentifier, err := binary.ReadUvarint(reader)
   243  		if err != nil {
   244  			return fmt.Errorf("unable to read stream identifier (message kind %#02x): %w", kind, err)
   245  		} else if streamIdentifier == 0 {
   246  			return fmt.Errorf("zero-value stream identifier received (message kind %#02x)", kind)
   247  		}
   248  
   249  		// Verify that the stream identifier falls with an acceptable range,
   250  		// depending on its origin and the message kind, and look up the
   251  		// corresponding stream object, if applicable.
   252  		streamIdentifierIsOutbound := m.even == (streamIdentifier%2 == 0)
   253  		var stream *Stream
   254  		if kind == messageKindStreamOpen {
   255  			if streamIdentifierIsOutbound {
   256  				return errors.New("outbound stream identifier used by remote to open stream")
   257  			} else if streamIdentifier <= largestOpenedInboundStreamIdentifier {
   258  				return errors.New("remote stream identifiers not monotonically increasing")
   259  			}
   260  			largestOpenedInboundStreamIdentifier = streamIdentifier
   261  		} else if kind == messageKindStreamAccept && !streamIdentifierIsOutbound {
   262  			return errors.New("inbound stream identifier used by remote to accept stream")
   263  		} else {
   264  			inboundStreamIdentifierOutOfRange := !streamIdentifierIsOutbound &&
   265  				streamIdentifier > largestOpenedInboundStreamIdentifier
   266  			if inboundStreamIdentifierOutOfRange {
   267  				return fmt.Errorf("message (%#02x) received for unopened inbound stream identifier", kind)
   268  			}
   269  			m.streamLock.Lock()
   270  			outboundStreamIdentifierOutOfRange := streamIdentifierIsOutbound &&
   271  				m.nextOutboundStreamIdentifier != 0 &&
   272  				streamIdentifier >= m.nextOutboundStreamIdentifier
   273  			if outboundStreamIdentifierOutOfRange {
   274  				m.streamLock.Unlock()
   275  				return fmt.Errorf("message (%#02x) received for unused outbound stream identifier", kind)
   276  			}
   277  			stream = m.streams[streamIdentifier]
   278  			m.streamLock.Unlock()
   279  		}
   280  
   281  		// Handle the remainder of the message based on kind.
   282  		if kind == messageKindStreamOpen {
   283  			// Decode the remote's initial receive window size.
   284  			windowSize, err := binary.ReadUvarint(reader)
   285  			if err != nil {
   286  				return fmt.Errorf("unable to read initial stream window size on open: %w", err)
   287  			}
   288  
   289  			// If there's no capacity for additional streams in the backlog,
   290  			// then enqueue a close message to reject the stream.
   291  			if len(m.pendingInboundStreamIdentifiers) == m.configuration.AcceptBacklog {
   292  				select {
   293  				case m.enqueueClose <- streamIdentifier:
   294  					continue
   295  				case <-m.closed:
   296  					return ErrMultiplexerClosed
   297  				}
   298  			}
   299  
   300  			// Create the local end of the stream.
   301  			stream := newStream(m, streamIdentifier, m.configuration.StreamReceiveWindow)
   302  
   303  			// Set the stream's initial write window.
   304  			stream.sendWindow = windowSize
   305  			if windowSize > 0 {
   306  				stream.sendWindowReady <- struct{}{}
   307  			}
   308  
   309  			// Register the stream.
   310  			m.streamLock.Lock()
   311  			m.streams[streamIdentifier] = stream
   312  			m.streamLock.Unlock()
   313  
   314  			// Enqueue the stream for acceptance.
   315  			m.pendingInboundStreamIdentifiers <- streamIdentifier
   316  		} else if kind == messageKindStreamAccept {
   317  			// Decode the remote's initial receive window size.
   318  			windowSize, err := binary.ReadUvarint(reader)
   319  			if err != nil {
   320  				return fmt.Errorf("unable to read initial stream window size on accept: %w", err)
   321  			}
   322  
   323  			// If the stream wasn't found locally, then we just have to assume
   324  			// that the open request was already cancelled and that a close
   325  			// response was already sent to the remote. In theory, there could
   326  			// be misbehavior here from the remote, but we have no way to track
   327  			// or detect it. In this case, we discard the message.
   328  			if stream == nil {
   329  				continue
   330  			}
   331  
   332  			// Verify that the stream wasn't already accepted or rejected.
   333  			if isClosed(stream.established) {
   334  				return errors.New("remote accepted the same stream twice")
   335  			} else if isClosed(stream.remoteClosed) {
   336  				return errors.New("remote accepted stream after closing it")
   337  			}
   338  
   339  			// Set the stream's initial write window. We don't need to lock the
   340  			// write window at this point since the stream hasn't been returned
   341  			// to the caller of OpenStream yet.
   342  			stream.sendWindow = windowSize
   343  			if windowSize > 0 {
   344  				stream.sendWindowReady <- struct{}{}
   345  			}
   346  
   347  			// Mark the stream as accepted.
   348  			close(stream.established)
   349  		} else if kind == messageKindStreamData {
   350  			// Decode the data length.
   351  			if _, err := io.ReadFull(reader, lengthBuffer[:]); err != nil {
   352  				return fmt.Errorf("unable to read data length: %w", err)
   353  			}
   354  			length := int(binary.BigEndian.Uint16(lengthBuffer[:]))
   355  			if length == 0 {
   356  				return errors.New("zero-length data received")
   357  			}
   358  
   359  			// If the stream wasn't found locally, then we just have to assume
   360  			// that it was already closed locally and deregistered. In theory,
   361  			// there could be misbehavior here from the remote, but we have no
   362  			// way to track or detect it. In this case, we discard the data.
   363  			if stream == nil {
   364  				if _, err := reader.Discard(length); err != nil {
   365  					return fmt.Errorf("unable to discard data: %w", err)
   366  				}
   367  				continue
   368  			}
   369  
   370  			// Verify that the stream has been established and isn't closed for
   371  			// writing or closed.
   372  			if !isClosed(stream.established) {
   373  				return errors.New("data received for partially established stream")
   374  			} else if isClosed(stream.remoteClosedWrite) {
   375  				return errors.New("data received for write-closed stream")
   376  			} else if isClosed(stream.remoteClosed) {
   377  				return errors.New("data received for closed stream")
   378  			}
   379  
   380  			// Record the data.
   381  			stream.receiveBufferLock.Lock()
   382  			if _, err := stream.receiveBuffer.ReadNFrom(reader, length); err != nil {
   383  				stream.receiveBufferLock.Unlock()
   384  				if err == ring.ErrBufferFull {
   385  					return errors.New("remote violated stream receive window")
   386  				}
   387  				return fmt.Errorf("unable to read stream data into buffer: %w", err)
   388  			}
   389  			if stream.receiveBuffer.Used() == length {
   390  				stream.receiveBufferReady <- struct{}{}
   391  			}
   392  			stream.receiveBufferLock.Unlock()
   393  		} else if kind == messageKindStreamWindowIncrement {
   394  			// Decode the remote's receive window size increment.
   395  			windowSizeIncrement, err := binary.ReadUvarint(reader)
   396  			if err != nil {
   397  				return fmt.Errorf("unable to read stream window size increment: %w", err)
   398  			} else if windowSizeIncrement == 0 {
   399  				return errors.New("zero-valued window increment received")
   400  			}
   401  
   402  			// If the stream wasn't found locally, then we just have to assume
   403  			// that it was already closed locally and deregistered. In theory,
   404  			// there could be misbehavior here from the remote, but we have no
   405  			// way to track or detect it. In this case, we discard the message.
   406  			if stream == nil {
   407  				continue
   408  			}
   409  
   410  			// If this is an outbound stream, then ensure that the stream is
   411  			// established (i.e. it's been accepted by the remote) before
   412  			// allowing window increments. For inbound streams, we allow
   413  			// adjustments to the window size before we accept the stream
   414  			// locally, even though we don't utilize this feature at the moment.
   415  			if streamIdentifierIsOutbound && !isClosed(stream.established) {
   416  				return errors.New("window increment received for partially established outbound stream")
   417  			}
   418  
   419  			// Verify that the stream isn't already closed.
   420  			if isClosed(stream.remoteClosed) {
   421  				return errors.New("window increment received for closed stream")
   422  			}
   423  
   424  			// Increment the window.
   425  			stream.sendWindowLock.Lock()
   426  			if stream.sendWindow == 0 {
   427  				stream.sendWindow = windowSizeIncrement
   428  				stream.sendWindowReady <- struct{}{}
   429  			} else {
   430  				if math.MaxUint64-stream.sendWindow < windowSizeIncrement {
   431  					stream.sendWindowLock.Unlock()
   432  					return errors.New("window increment overflows maximum value")
   433  				}
   434  				stream.sendWindow += windowSizeIncrement
   435  			}
   436  			stream.sendWindowLock.Unlock()
   437  		} else if kind == messageKindStreamCloseWrite {
   438  			// If the stream wasn't found locally, then we just have to assume
   439  			// that it was already closed locally and deregistered. In theory,
   440  			// there could be misbehavior here from the remote, but we have no
   441  			// way to track or detect it. In this case, we discard the message.
   442  			if stream == nil {
   443  				continue
   444  			}
   445  
   446  			// If this is an outbound stream, then ensure that the stream is
   447  			// established (i.e. it's been accepted by the remote) before
   448  			// allowing write closure. For inbound streams, we allow write
   449  			// closure before we accept the stream locally, even though we don't
   450  			// utilize this feature at the moment.
   451  			if streamIdentifierIsOutbound && !isClosed(stream.established) {
   452  				return errors.New("close write received for partially established outbound stream")
   453  			}
   454  
   455  			// Verify that the stream isn't already closed or closed for writes.
   456  			if isClosed(stream.remoteClosed) {
   457  				return errors.New("close write received for closed stream")
   458  			} else if isClosed(stream.remoteClosedWrite) {
   459  				return errors.New("close write received for the same stream twice")
   460  			}
   461  
   462  			// Signal write closure.
   463  			close(stream.remoteClosedWrite)
   464  		} else if kind == messageKindStreamClose {
   465  			// If the stream wasn't found locally, then we just have to assume
   466  			// that it was already closed locally and deregistered. In theory,
   467  			// there could be misbehavior here from the remote, but we have no
   468  			// way to track or detect it. In this case, we discard the message.
   469  			if stream == nil {
   470  				continue
   471  			}
   472  
   473  			// Verify that the stream isn't already closed.
   474  			if isClosed(stream.remoteClosed) {
   475  				return errors.New("close received the same stream twice")
   476  			}
   477  
   478  			// Signal closure.
   479  			close(stream.remoteClosed)
   480  		} else {
   481  			panic("unhandled message kind")
   482  		}
   483  	}
   484  }
   485  
   486  // write is the entry point for the writer Goroutine.
   487  func (m *Multiplexer) write(writer Carrier) error {
   488  	// If outbound heartbeats are enabled, then create a ticker to regulate
   489  	// heartbeat transmission, defer its shutdown, and craft a reusable
   490  	// heartbeat message.
   491  	var heartbeatTicker *time.Ticker
   492  	var writeHeartbeat <-chan time.Time
   493  	var heartbeat []byte
   494  	if m.configuration.HeartbeatTransmitInterval > 0 {
   495  		heartbeatTicker = time.NewTicker(m.configuration.HeartbeatTransmitInterval)
   496  		defer heartbeatTicker.Stop()
   497  		writeHeartbeat = heartbeatTicker.C
   498  		heartbeat = []byte{byte(messageKindMultiplexerHeartbeat)}
   499  	}
   500  
   501  	// Loop until failure or multiplexer closure.
   502  	for {
   503  		select {
   504  		case <-writeHeartbeat:
   505  			if _, err := writer.Write(heartbeat); err != nil {
   506  				return fmt.Errorf("unable to write heartbeat: %w", err)
   507  			}
   508  		case writeBuffer := <-m.writeBufferPending:
   509  			if _, err := writeBuffer.WriteTo(writer); err != nil {
   510  				return fmt.Errorf("unable to write message buffer: %w", err)
   511  			}
   512  			m.writeBufferAvailable <- writeBuffer
   513  		case <-m.closed:
   514  			return ErrMultiplexerClosed
   515  		}
   516  	}
   517  }
   518  
   519  // enqueue is the entry point for the state accumulation/transmission Goroutine.
   520  func (m *Multiplexer) enqueue() {
   521  	// Track pending updates.
   522  	windowIncrements := make(map[uint64]uint64)
   523  	writeCloses := make(map[uint64]bool)
   524  	closes := make(map[uint64]bool)
   525  
   526  	// Loop and process updates until failure.
   527  	for {
   528  		// Determine whether or not to poll for write buffer availability (based
   529  		// on whether or not we have any pending updates).
   530  		writeBufferAvailable := m.writeBufferAvailable
   531  		if len(windowIncrements) == 0 && len(writeCloses) == 0 && len(closes) == 0 {
   532  			writeBufferAvailable = nil
   533  		}
   534  
   535  		// Poll for a write buffer (if applicable), an update, or termination.
   536  		// If we get a write buffer, then write as many updates as we can.
   537  		select {
   538  		case writeBuffer := <-writeBufferAvailable:
   539  			for stream, amount := range windowIncrements {
   540  				if writeBuffer.canEncodeStreamWindowIncrement() {
   541  					writeBuffer.encodeStreamWindowIncrement(stream, amount)
   542  					delete(windowIncrements, stream)
   543  				} else {
   544  					break
   545  				}
   546  			}
   547  			for stream := range writeCloses {
   548  				if writeBuffer.canEncodeStreamCloseWrite() {
   549  					writeBuffer.encodeStreamCloseWrite(stream)
   550  					delete(writeCloses, stream)
   551  				} else {
   552  					break
   553  				}
   554  			}
   555  			for stream := range closes {
   556  				if writeBuffer.canEncodeStreamClose() {
   557  					writeBuffer.encodeStreamClose(stream)
   558  					delete(closes, stream)
   559  				} else {
   560  					break
   561  				}
   562  			}
   563  			m.writeBufferPending <- writeBuffer
   564  		case increment := <-m.enqueueWindowIncrement:
   565  			windowIncrements[increment.stream] = windowIncrements[increment.stream] + increment.amount
   566  		case stream := <-m.enqueueCloseWrite:
   567  			writeCloses[stream] = true
   568  		case stream := <-m.enqueueClose:
   569  			delete(windowIncrements, stream)
   570  			delete(writeCloses, stream)
   571  			closes[stream] = true
   572  		case <-m.closed:
   573  			return
   574  		}
   575  	}
   576  }
   577  
// Addr implements net.Listener.Addr. It returns a newly allocated multiplexer
// address that records whether this endpoint uses even-numbered outbound stream
// identifiers.
func (m *Multiplexer) Addr() net.Addr {
	return &multiplexerAddress{even: m.even}
}
   582  
   583  // OpenStream opens a new stream, cancelling the open operation if the provided
   584  // context is cancelled, an error occurs, or the multiplexer is closed. The
   585  // context must not be nil. The context only regulates the lifetime of the open
   586  // operation, not the stream itself.
   587  func (m *Multiplexer) OpenStream(ctx context.Context) (*Stream, error) {
   588  	// Create and register the local side of the stream. If we've already
   589  	// exhausted local stream identifiers, then we can't open a new stream.
   590  	m.streamLock.Lock()
   591  	if m.nextOutboundStreamIdentifier == 0 {
   592  		m.streamLock.Unlock()
   593  		return nil, errors.New("local stream identifiers exhausted")
   594  	}
   595  	stream := newStream(m, m.nextOutboundStreamIdentifier, m.configuration.StreamReceiveWindow)
   596  	m.streams[m.nextOutboundStreamIdentifier] = stream
   597  	if math.MaxUint64-m.nextOutboundStreamIdentifier < 2 {
   598  		m.nextOutboundStreamIdentifier = 0
   599  	} else {
   600  		m.nextOutboundStreamIdentifier += 2
   601  	}
   602  	m.streamLock.Unlock()
   603  
   604  	// If we fail to establish the stream, then defer its closure. We can't use
   605  	// the stream's established channel to check this because it could be closed
   606  	// by the reader Goroutine after some other error aborts the opening.
   607  	var sentOpenMessage, established bool
   608  	defer func() {
   609  		if !established {
   610  			stream.close(sentOpenMessage)
   611  		}
   612  	}()
   613  
   614  	// Write the open message and queue it for transmission.
   615  	select {
   616  	case writeBuffer := <-m.writeBufferAvailable:
   617  		writeBuffer.encodeOpenMessage(stream.identifier, uint64(m.configuration.StreamReceiveWindow))
   618  		m.writeBufferPending <- writeBuffer
   619  		sentOpenMessage = true
   620  	case <-ctx.Done():
   621  		return nil, context.Canceled
   622  	case <-m.closed:
   623  		return nil, ErrMultiplexerClosed
   624  	}
   625  
   626  	// Wait for stream acceptance or rejection.
   627  	select {
   628  	case <-stream.established:
   629  		established = true
   630  		return stream, nil
   631  	case <-stream.remoteClosed:
   632  		return nil, ErrStreamRejected
   633  	case <-ctx.Done():
   634  		return nil, context.Canceled
   635  	case <-m.closed:
   636  		return nil, ErrMultiplexerClosed
   637  	}
   638  }
   639  
   640  // errStaleInboundStream indicates that a stale inbound stream was encountered.
   641  var errStaleInboundStream = errors.New("stale inbound stream")
   642  
   643  // acceptOneStream is the internal stream accept method. It will only attempt
   644  // one accept, and will return errStaleInboundStream if the accept request fails
   645  // due to a stale inbound stream.
   646  func (m *Multiplexer) acceptOneStream(ctx context.Context) (*Stream, error) {
   647  	// Grab the oldest pending stream identifier.
   648  	var streamIdentifier uint64
   649  	select {
   650  	case streamIdentifier = <-m.pendingInboundStreamIdentifiers:
   651  	case <-ctx.Done():
   652  		return nil, context.Canceled
   653  	case <-m.closed:
   654  		return nil, ErrMultiplexerClosed
   655  	}
   656  
   657  	// Grab the associated stream object, which is guaranteed to be non-nil.
   658  	m.streamLock.Lock()
   659  	stream := m.streams[streamIdentifier]
   660  	m.streamLock.Unlock()
   661  
   662  	// If we fail to establish the stream, then defer its closure. In this case
   663  	// (unlike the opening case) we can use the stream's established channel to
   664  	// check this because we're responsible for closing it.
   665  	defer func() {
   666  		if !isClosed(stream.established) {
   667  			stream.Close()
   668  		}
   669  	}()
   670  
   671  	// Wait for a write buffer to become available.
   672  	var writeBuffer *messageBuffer
   673  	select {
   674  	case writeBuffer = <-m.writeBufferAvailable:
   675  	case <-stream.remoteClosed:
   676  		return nil, errStaleInboundStream
   677  	case <-ctx.Done():
   678  		return nil, context.Canceled
   679  	case <-m.closed:
   680  		return nil, ErrMultiplexerClosed
   681  	}
   682  
   683  	// Mark the stream as established. We need to do this before transmitting
   684  	// the accept message because the other side might start sending messages
   685  	// immediately and the reader Goroutine will want to confirm establishment
   686  	// when processing those messages.
   687  	close(stream.established)
   688  
   689  	// Write the accept message and queue it for transmission.
   690  	writeBuffer.encodeAcceptMessage(streamIdentifier, uint64(m.configuration.StreamReceiveWindow))
   691  	m.writeBufferPending <- writeBuffer
   692  
   693  	// Success.
   694  	return stream, nil
   695  }
   696  
   697  // AcceptContext accepts an incoming stream.
   698  func (m *Multiplexer) AcceptStream(ctx context.Context) (*Stream, error) {
   699  	// Loop until we find a pending stream that's not stale or encounter some
   700  	// other error.
   701  	for {
   702  		stream, err := m.acceptOneStream(ctx)
   703  		if err == errStaleInboundStream {
   704  			continue
   705  		}
   706  		return stream, err
   707  	}
   708  }
   709  
   710  // Accept implements net.Listener.Accept. It is implemented as a wrapper around
   711  // AcceptStream and simply casts the resulting stream to a net.Conn.
   712  func (m *Multiplexer) Accept() (net.Conn, error) {
   713  	stream, err := m.AcceptStream(context.Background())
   714  	return stream, err
   715  }
   716  
// Closed returns a channel that is closed when the multiplexer is closed (due
// to either internal failure or a manual call to Close). The channel is
// receive-only and never carries values; callers should only wait on it.
func (m *Multiplexer) Closed() <-chan struct{} {
	return m.closed
}
   722  
   723  // InternalError returns any internal error that caused the multiplexer to
   724  // close (as indicated by closure of the result of Closed). It returns nil if
   725  // Close was manually invoked.
   726  func (m *Multiplexer) InternalError() error {
   727  	m.internalErrorLock.RLock()
   728  	defer m.internalErrorLock.RUnlock()
   729  	return m.internalError
   730  }
   731  
   732  // closeWithError is the internal close method that allows for optional error
   733  // reporting when closing.
   734  func (m *Multiplexer) closeWithError(internalError error) (err error) {
   735  	m.closeOnce.Do(func() {
   736  		err = m.closer.Close()
   737  		if internalError != nil {
   738  			m.internalErrorLock.Lock()
   739  			m.internalError = internalError
   740  			m.internalErrorLock.Unlock()
   741  		}
   742  		close(m.closed)
   743  	})
   744  	return
   745  }
   746  
// Close implements net.Listener.Close. Only the first call to Close will have
// any effect. Subsequent calls will behave as no-ops and return nil errors. No
// internal error is recorded for a manual close, so InternalError returns nil
// if this call is the one that closes the multiplexer.
func (m *Multiplexer) Close() error {
	return m.closeWithError(nil)
}
   751  }