tractor.dev/toolkit-go@v0.0.0-20241010005851-214d91207d07/duplex/mux/channel.go (about)

     1  package mux
     2  
     3  import (
     4  	"errors"
     5  	"fmt"
     6  	"io"
     7  	"sync"
     8  
     9  	"tractor.dev/toolkit-go/duplex/mux/frame"
    10  )
    11  
    12  type channelDirection uint8
    13  
    14  const (
    15  	channelInbound channelDirection = iota
    16  	channelOutbound
    17  )
    18  
    19  func min(a uint32, b int) uint32 {
    20  	if a < uint32(b) {
    21  		return a
    22  	}
    23  	return uint32(b)
    24  }
    25  
    26  type Channel interface {
    27  	io.ReadWriteCloser
    28  	ID() uint32
    29  	CloseWrite() error
    30  }
    31  
// channel is the implementation of the Channel interface that is used
// together with session: it writes its outgoing frames through
// session.enc, and handle processes the frames addressed to it.
type channel struct {
	// localId is this side's identifier for the channel (returned by ID).
	// remoteId is the ChannelID carried by every frame sent to the peer.
	// Note that handle overwrites remoteId when an OpenConfirmMessage
	// arrives, so it is not read-only after creation for outbound channels.
	localId, remoteId uint32

	// maxIncomingPayload is the largest DataMessage payload accepted from
	// the peer (enforced in handleData). maxRemotePayload caps the payload
	// of each DataMessage sent by Write; handle sets it from the peer's
	// OpenConfirmMessage MaxPacketSize.
	maxIncomingPayload uint32
	maxRemotePayload   uint32

	// session owns this channel. Outgoing frames are encoded with
	// session.enc, and handle removes the channel from session.chans when
	// it is closed by the peer or its open fails.
	session *session

	// direction contains either channelOutbound, for channels created
	// locally, or channelInbound, for channels created by the peer.
	direction channelDirection

	// msg carries OpenConfirmMessage / OpenFailureMessage frames from
	// handle to whoever waits for the open to complete (NOTE(review): the
	// receiver is not in this file). close() closes it.
	msg chan frame.Message

	// sentEOF is set by CloseWrite and makes subsequent Writes fail with
	// io.EOF. NOTE(review): it is read and written without holding a lock.
	sentEOF bool

	// remoteWin is the peer's receive window. Write reserves space from it,
	// handle grows it on WindowAdjustMessage and OpenConfirmMessage, and
	// close() closes it to unblock writers. pending buffers incoming payloads
	// for Read; handleData writes to it, and an EOFMessage or close() calls
	// its eof method. Both are used without a channel-level lock, so they
	// are presumably synchronized internally (types defined elsewhere —
	// confirm).
	remoteWin window
	pending   *buffer

	// windowMu protects myWindow, this side's receive window. handleData
	// subtracts each accepted payload from it, and adjustWindow adds back
	// the bytes that Read consumed.
	windowMu sync.Mutex
	myWindow uint32

	// writeMu protects sentClose. send holds it across the check and the
	// encode, so no frame sent via send is encoded after a CloseMessage;
	// close() also sets sentClose under it. NOTE(review): Write encodes data
	// frames without taking writeMu.
	writeMu   sync.Mutex
	sentClose bool

	// packetBuf is a scratch buffer for writing. NOTE(review): it is not
	// referenced anywhere in this file and may be unused.
	packetBuf []byte
}
    75  
    76  // ID returns the unique identifier of this channel
    77  // within the session
    78  func (ch *channel) ID() uint32 {
    79  	return ch.localId
    80  }
    81  
    82  // CloseWrite signals the end of sending data.
    83  // The other side may still send data
    84  func (ch *channel) CloseWrite() error {
    85  	ch.sentEOF = true
    86  	return ch.send(frame.EOFMessage{
    87  		ChannelID: ch.remoteId})
    88  }
    89  
    90  // Close signals end of channel use. No data may be sent after this
    91  // call.
    92  func (ch *channel) Close() error {
    93  	return ch.send(frame.CloseMessage{
    94  		ChannelID: ch.remoteId})
    95  }
    96  
    97  // Write writes len(data) bytes to the channel.
    98  func (ch *channel) Write(data []byte) (n int, err error) {
    99  	if ch.sentEOF {
   100  		return 0, io.EOF
   101  	}
   102  
   103  	for len(data) > 0 {
   104  		space := min(ch.maxRemotePayload, len(data))
   105  		if space, err = ch.remoteWin.reserve(space); err != nil {
   106  			return n, err
   107  		}
   108  
   109  		toSend := data[:space]
   110  
   111  		if err = ch.session.enc.Encode(frame.DataMessage{
   112  			ChannelID: ch.remoteId,
   113  			Length:    uint32(len(toSend)),
   114  			Data:      toSend,
   115  		}); err != nil {
   116  			return n, err
   117  		}
   118  
   119  		n += len(toSend)
   120  		data = data[len(toSend):]
   121  	}
   122  
   123  	return n, err
   124  }
   125  
   126  // Read reads up to len(data) bytes from the channel.
   127  func (c *channel) Read(data []byte) (n int, err error) {
   128  	n, err = c.pending.Read(data)
   129  
   130  	if n > 0 {
   131  		err = c.adjustWindow(uint32(n))
   132  		// sendWindowAdjust can return io.EOF if the remote
   133  		// peer has closed the connection, however we want to
   134  		// defer forwarding io.EOF to the caller of Read until
   135  		// the buffer has been drained.
   136  		if n > 0 && err == io.EOF {
   137  			err = nil
   138  		}
   139  	}
   140  	return n, err
   141  }
   142  
   143  // sends writes a message frame. If the message is a channel close, it updates
   144  // sentClose. This method takes the lock c.writeMu.
   145  func (ch *channel) send(msg frame.Message) error {
   146  	ch.writeMu.Lock()
   147  	defer ch.writeMu.Unlock()
   148  
   149  	if ch.sentClose {
   150  		return io.EOF
   151  	}
   152  
   153  	if _, ok := msg.(frame.CloseMessage); ok {
   154  		ch.sentClose = true
   155  	}
   156  
   157  	return ch.session.enc.Encode(msg)
   158  }
   159  
   160  func (c *channel) adjustWindow(n uint32) error {
   161  	c.windowMu.Lock()
   162  	// Since myWindow is managed on our side, and can never exceed
   163  	// the initial window setting, we don't worry about overflow.
   164  	c.myWindow += uint32(n)
   165  	c.windowMu.Unlock()
   166  	return c.send(frame.WindowAdjustMessage{
   167  		ChannelID:       c.remoteId,
   168  		AdditionalBytes: uint32(n),
   169  	})
   170  }
   171  
   172  func (c *channel) close() {
   173  	c.pending.eof()
   174  	close(c.msg)
   175  	c.writeMu.Lock()
   176  	// This is not necessary for a normal channel teardown, but if
   177  	// there was another error, it is.
   178  	c.sentClose = true
   179  	c.writeMu.Unlock()
   180  	// Unblock writers.
   181  	c.remoteWin.close()
   182  }
   183  
   184  // responseMessageReceived is called when a success or failure message is
   185  // received on a channel to check that such a message is reasonable for the
   186  // given channel.
   187  func (ch *channel) responseMessageReceived() error {
   188  	if ch.direction == channelInbound {
   189  		return errors.New("qmux: channel response message received on inbound channel")
   190  	}
   191  	return nil
   192  }
   193  
   194  func (ch *channel) handle(msg frame.Message) error {
   195  	switch m := msg.(type) {
   196  	case *frame.DataMessage:
   197  		return ch.handleData(m)
   198  
   199  	case *frame.CloseMessage:
   200  		ch.send(frame.CloseMessage{
   201  			ChannelID: ch.remoteId,
   202  		})
   203  		ch.session.chans.remove(ch.localId)
   204  		ch.close()
   205  		return nil
   206  
   207  	case *frame.EOFMessage:
   208  		ch.pending.eof()
   209  		return nil
   210  
   211  	case *frame.WindowAdjustMessage:
   212  		if !ch.remoteWin.add(m.AdditionalBytes) {
   213  			return fmt.Errorf("qmux: invalid window update for %d bytes", m.AdditionalBytes)
   214  		}
   215  		return nil
   216  
   217  	case *frame.OpenConfirmMessage:
   218  		if err := ch.responseMessageReceived(); err != nil {
   219  			return err
   220  		}
   221  		if m.MaxPacketSize < minPacketLength || m.MaxPacketSize > maxPacketLength {
   222  			return fmt.Errorf("qmux: invalid MaxPacketSize %d from peer", m.MaxPacketSize)
   223  		}
   224  		ch.remoteId = m.SenderID
   225  		ch.maxRemotePayload = m.MaxPacketSize
   226  		ch.remoteWin.add(m.WindowSize)
   227  		ch.msg <- m
   228  		return nil
   229  
   230  	case *frame.OpenFailureMessage:
   231  		if err := ch.responseMessageReceived(); err != nil {
   232  			return err
   233  		}
   234  		ch.session.chans.remove(m.ChannelID)
   235  		ch.msg <- m
   236  		return nil
   237  
   238  	default:
   239  		return fmt.Errorf("qmux: invalid channel message %v", msg)
   240  	}
   241  }
   242  
   243  func (ch *channel) handleData(msg *frame.DataMessage) error {
   244  	if msg.Length > ch.maxIncomingPayload {
   245  		// TODO(hanwen): should send Disconnect?
   246  		return errors.New("qmux: incoming packet exceeds maximum payload size")
   247  	}
   248  
   249  	if msg.Length != uint32(len(msg.Data)) {
   250  		return errors.New("qmux: wrong packet length")
   251  	}
   252  
   253  	ch.windowMu.Lock()
   254  	if ch.myWindow < msg.Length {
   255  		ch.windowMu.Unlock()
   256  		// TODO(hanwen): should send Disconnect with reason?
   257  		return errors.New("qmux: remote side wrote too much")
   258  	}
   259  	ch.myWindow -= msg.Length
   260  	ch.windowMu.Unlock()
   261  
   262  	ch.pending.write(msg.Data)
   263  	return nil
   264  }