tractor.dev/toolkit-go@v0.0.0-20241010005851-214d91207d07/duplex/mux/session.go (about)

     1  package mux
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"io"
     7  	"net"
     8  	"sync"
     9  	"time"
    10  
    11  	"tractor.dev/toolkit-go/duplex/mux/frame"
    12  )
    13  
    14  const (
    15  	minPacketLength = 9
    16  	maxPacketLength = 1 << 31
    17  
    18  	// channelMaxPacket contains the maximum number of bytes that will be
    19  	// sent in a single packet.
    20  	channelMaxPacket = 1 << 24 // ~16MB, arbitrary
    21  	// We follow OpenSSH here.
    22  	channelWindowSize = 64 * channelMaxPacket
    23  
    24  	// chanSize sets the amount of buffering qmux connections. This is
    25  	// primarily for testing: setting chanSize=0 uncovers deadlocks more
    26  	// quickly.
    27  	chanSize = 16
    28  )
    29  
    30  var (
    31  	// timeout for queuing a new channel to be `Accept`ed
    32  	// use a `var` so that this can be overridden in tests
    33  	openTimeout = 30 * time.Second
    34  )
    35  
    36  // Session is a bi-directional channel muxing session on a given transport.
    37  type Session interface {
    38  	io.Closer
    39  	Accept() (Channel, error)
    40  	Open(ctx context.Context) (Channel, error)
    41  	Wait() error
    42  }
    43  
// session is the concrete Session returned by New. It owns the transport,
// the frame encoder/decoder over it, and the registry of channels
// multiplexed on it.
type session struct {
	// t is the underlying transport; both Close and loop close it.
	t io.ReadWriteCloser

	// chans is the channel registry: newChannel adds to it, onePacket looks
	// channels up by id, and loop drops all of them on shutdown.
	chans chanList

	// enc writes frames to t; dec reads frames from t.
	enc *frame.Encoder
	dec *frame.Decoder

	// inbox hands channels opened by the remote side (see handleOpen) to
	// Accept. New creates it unbuffered.
	inbox chan Channel

	// errCond guards err. loop sets err once the read loop ends and
	// broadcasts on errCond; Wait blocks on it until err is non-nil.
	errCond *sync.Cond
	err     error

	// closeCh is signaled by loop on shutdown so that Accept can return
	// io.EOF.
	closeCh chan bool
}
    57  
    58  // NewSession returns a session that runs over the given transport.
    59  func New(t io.ReadWriteCloser) Session {
    60  	if t == nil {
    61  		return nil
    62  	}
    63  	s := &session{
    64  		t:       t,
    65  		enc:     frame.NewEncoder(t),
    66  		dec:     frame.NewDecoder(t),
    67  		inbox:   make(chan Channel),
    68  		errCond: sync.NewCond(new(sync.Mutex)),
    69  		closeCh: make(chan bool, 1),
    70  	}
    71  	go s.loop()
    72  	return s
    73  }
    74  
    75  // Close closes the underlying transport.
    76  func (s *session) Close() error {
    77  	s.t.Close()
    78  	return nil
    79  }
    80  
    81  // Wait blocks until the transport has shut down, and returns the
    82  // error causing the shutdown.
    83  func (s *session) Wait() error {
    84  	s.errCond.L.Lock()
    85  	defer s.errCond.L.Unlock()
    86  	for s.err == nil {
    87  		s.errCond.Wait()
    88  	}
    89  	return s.err
    90  }
    91  
    92  // Accept waits for and returns the next incoming channel.
    93  func (s *session) Accept() (Channel, error) {
    94  	select {
    95  	case ch := <-s.inbox:
    96  		return ch, nil
    97  	case <-s.closeCh:
    98  		return nil, io.EOF
    99  	}
   100  }
   101  
   102  // Open establishes a new channel with the other end.
   103  func (s *session) Open(ctx context.Context) (Channel, error) {
   104  	ch := s.newChannel(channelOutbound)
   105  	ch.maxIncomingPayload = channelMaxPacket
   106  
   107  	if err := s.enc.Encode(frame.OpenMessage{
   108  		WindowSize:    ch.myWindow,
   109  		MaxPacketSize: ch.maxIncomingPayload,
   110  		SenderID:      ch.localId,
   111  	}); err != nil {
   112  		return nil, err
   113  	}
   114  
   115  	var m frame.Message
   116  
   117  	select {
   118  	case <-ctx.Done():
   119  		return nil, ctx.Err()
   120  	case m = <-ch.msg:
   121  		if m == nil {
   122  			// channel was closed before open got a response,
   123  			// typically meaning the session/conn was closed.
   124  			return nil, net.ErrClosed
   125  		}
   126  	}
   127  
   128  	switch msg := m.(type) {
   129  	case *frame.OpenConfirmMessage:
   130  		return ch, nil
   131  	case *frame.OpenFailureMessage:
   132  		return nil, fmt.Errorf("qmux: channel open failed on remote side")
   133  	default:
   134  		return nil, fmt.Errorf("qmux: unexpected packet in response to channel open: %v", msg)
   135  	}
   136  }
   137  
   138  func (s *session) newChannel(direction channelDirection) *channel {
   139  	ch := &channel{
   140  		remoteWin: window{Cond: sync.NewCond(new(sync.Mutex))},
   141  		myWindow:  channelWindowSize,
   142  		pending:   newBuffer(),
   143  		direction: direction,
   144  		msg:       make(chan frame.Message, chanSize),
   145  		session:   s,
   146  		packetBuf: make([]byte, 0),
   147  	}
   148  	ch.localId = s.chans.add(ch)
   149  	return ch
   150  }
   151  
   152  // loop runs the connection machine. It will process packets until an
   153  // error is encountered. To synchronize on loop exit, use session.Wait.
   154  func (s *session) loop() {
   155  	var err error
   156  	for err == nil {
   157  		err = s.onePacket()
   158  	}
   159  
   160  	for _, ch := range s.chans.dropAll() {
   161  		ch.close()
   162  	}
   163  
   164  	s.t.Close()
   165  	s.closeCh <- true
   166  
   167  	s.errCond.L.Lock()
   168  	s.err = err
   169  	s.errCond.Broadcast()
   170  	s.errCond.L.Unlock()
   171  }
   172  
   173  // onePacket reads and processes one packet.
   174  func (s *session) onePacket() error {
   175  	var err error
   176  	var msg frame.Message
   177  
   178  	msg, err = s.dec.Decode()
   179  	if err != nil {
   180  		return err
   181  	}
   182  
   183  	id, isChan := msg.Channel()
   184  	if !isChan {
   185  		return s.handleOpen(msg.(*frame.OpenMessage))
   186  	}
   187  
   188  	ch := s.chans.getChan(id)
   189  	if ch == nil {
   190  		return fmt.Errorf("qmux: invalid channel %d", id)
   191  	}
   192  
   193  	return ch.handle(msg)
   194  }
   195  
   196  // handleChannelOpen schedules a channel to be Accept()ed.
   197  func (s *session) handleOpen(msg *frame.OpenMessage) error {
   198  	if msg.MaxPacketSize < minPacketLength || msg.MaxPacketSize > maxPacketLength {
   199  		return s.enc.Encode(frame.OpenFailureMessage{
   200  			ChannelID: msg.SenderID,
   201  		})
   202  	}
   203  
   204  	c := s.newChannel(channelInbound)
   205  	c.remoteId = msg.SenderID
   206  	c.maxRemotePayload = msg.MaxPacketSize
   207  	c.remoteWin.add(msg.WindowSize)
   208  	c.maxIncomingPayload = channelMaxPacket
   209  	t := time.NewTimer(openTimeout)
   210  	defer t.Stop()
   211  	select {
   212  	case s.inbox <- c:
   213  		return s.enc.Encode(frame.OpenConfirmMessage{
   214  			ChannelID:     c.remoteId,
   215  			SenderID:      c.localId,
   216  			WindowSize:    c.myWindow,
   217  			MaxPacketSize: c.maxIncomingPayload,
   218  		})
   219  	case <-t.C:
   220  		return s.enc.Encode(frame.OpenFailureMessage{
   221  			ChannelID: msg.SenderID,
   222  		})
   223  	}
   224  }